diff --git a/config/tests.config b/config/tests.config index a5762df7d9..e9feff8472 100644 --- a/config/tests.config +++ b/config/tests.config @@ -14,4 +14,11 @@ exclude = examples/MEE-CBC examples/old examples/old/list-ddh !examples/incomple okdirs = examples/MEE-CBC [test-unit] -okdirs = tests tests/exception +okdirs = !tests +exclude = tests/tc-ko tests/exception !tests/require_test + +[test-exception] +okdirs = tests/exception + +[test-tc-ko] +kodirs = !tests/tc-ko diff --git a/doc/typeclasses-inference.md b/doc/typeclasses-inference.md new file mode 100644 index 0000000000..0c641d3f76 --- /dev/null +++ b/doc/typeclasses-inference.md @@ -0,0 +1,202 @@ +# Typeclass inference — design + +Companion to [typeclasses.md](typeclasses.md). Covers what the unifier +does when it encounters a `\`TcCtt(uid, ty, tc)` problem, why the current +single-axis approach is insufficient for multi-parameter typeclasses, +and the strategy framework that resolves this. + +--- + +## Background — `\`TcCtt` problems + +Whenever the typer needs a typeclass witness, it generates a problem of +the form + +``` +TcCtt (uid, ty, tc) +``` + +meaning "find a witness for `ty : tc`, and bind it to the witness +univar `uid`". The unifier's job is to either resolve `uid` to a +concrete `tcwitness` or report failure. + +Three things vary: + +1. **`ty`** — the carrier. Can be ground (`int`), abstract (`Tvar a`, + `Tconstr abs_p _`), or a univar (`Tunivar u`). +2. **`tc.tc_args`** — the type-class's auxiliary type parameters, for + parametric typeclasses like `('a, 'b) embed`. Each can be ground or + contain univars. +3. **The environment** — `tvtc` for `Tvar` carriers, the typeclass + declaration for `Tconstr abs_p`, and the instance database for + ground carriers. + +The current resolver is in `ecUnify.ml`, in the `\`TcCtt` arm of +`unify_core`. + +## Catalog of inference modes + +Every TcCtt problem falls into one of these shapes. Each row says what +information the resolver has and what it should produce. + +| # | Carrier `ty` | `tc.args` | Status today | Resolver action | +|----|---------------------------|--------------------------|---------------------------|----------------------------------------------------------------| +| 1 | ground | ground | works | `EcTypeClass.infer env ty tc` → `TCIConcrete` | +| 2 | ground | partly univar | partly works | `infer` already pattern-matches instance args, fills univars | +| 3 | univar | ground | **fails** (parks forever) | walk instances, find unique match by `tc.args`, unify carrier | +| 4 | univar | partly univar | parks | wait — too underdetermined to infer either side | +| 5 | `Tvar a`, `a ∈ tvtc` | any | works | walk `tvtc[a]`'s ancestors, return `TCIAbstract { Var a; .. }` | +| 6 | `Tconstr abs_p _` | any | works | walk decl's `tcs`, return `TCIAbstract { Abs abs_p; .. }` | +| 7 | ground tuple/fun | any | upstream rejects instance | (n/a) — but `subst_tcw` has a latent `assert false` | +| 8 | `Tvar a`, `a ∉ tvtc` | any | failure | error: "unconstrained type variable" | + +Modes #1, #2, #5, #6 are covered. Mode #3 is the bare-op gap. Modes #4 +and #7 are deferred (#4 has no inference to do; #7 is upstream). + +A future row would add *e.g.*: + +| ? | `Fapp` carrier (HO) | any | not designed | escape hatch / explicit tvi | + +## Why the current resolver doesn't cover Mode #3 + +The resolver's flow: + +``` +if TyUni.Suid.is_empty deps then + (* Mode #1, #2, #5, #6 *) + resolve and bind uid +else + (* Mode #3, #4 *) + for each univar in deps, register uid in byunivar map + wait for the univar to resolve +``` + +When `ty = Tunivar u`, `deps = {u}`. The resolver parks the problem. +It re-fires only when `u` is bound by some other equation. For Mode #3 +there is no such equation — the carrier's only constraint is the +typeclass itself. + +The fix is to attempt **forward inference** in this case: if `tc.args` +are ground and exactly one instance of `tc` matches, bind `u` to its +`tci_type`. + +## Strategy framework (Phase 2) + +Replace the single big `\`TcCtt` arm with a list of strategies. Each +strategy is: + +```ocaml +type tcw_strategy = { + name : string; + applicable : tcenv -> tcuni -> ty -> typeclass -> bool; + apply : EcEnv.env -> ucore -> tcuni -> ty -> typeclass + -> ucore * tcw_result; + triggers : tcw_trigger list; +} + +and tcw_result = + | Resolved of tcwitness + | Stuck (* park, retry on triggers *) + | Failed of failure_reason + | NoSuchInstance + +and tcw_trigger = + | OnUnivarResolved of tyuni (* re-fire when this tyuni binds *) + | OnTcUniResolved of tcuni (* re-fire when this tcuni binds *) +``` + +The dispatcher iterates strategies in priority order, stops on the +first non-`Stuck` result. + +Today's resolver becomes a list of strategies: + +| Priority | Strategy | Mode | +|----------|--------------------|------| +| 1 | `tvar_via_tvtc` | #5 | +| 2 | `abs_via_decl` | #6 | +| 3 | `infer_by_carrier` | #1, #2 | +| 4 *new* | `infer_by_args` | #3 | +| 5 | `defer` | #4 | + +Behaviour with strategies 1-3 + 5 is identical to today's resolver; +adding strategy 4 closes Mode #3. + +The `triggers` field is what lets us avoid the current implicit +re-seeding (which today re-pushes every parked problem at the start of +every `unify_core` call). With explicit triggers we only re-fire what +the latest binding could have made progress on. This is performance +hygiene; not strictly required for correctness. + +## By-args strategy (Phase 3) + +``` +applicable(tcenv, uid, ty, tc): + ty is Tunivar u AND + tc.args contains no univars + +apply(env, uc, uid, ty, tc): + candidates = + [ inst | inst ∈ TcInstance.get_all env, + inst.tci_instance = `General (tgp, _), + tgp.tc_name = tc.tc_name, + etyargs_match env (List.fst inst.tci_params) + ~patterns:tgp.tc_args ~etyargs:tc.tc_args + succeeds with map M ] + + match candidates: + | [] -> NoSuchInstance + | [inst, M] -> let carrier = subst M inst.tci_type in + unify env uc ty carrier ; + Resolved (TCIConcrete { path = inst_path; + etyargs = subst M inst.tci_params; + lift = 0 }) + | _ :: _ :: _-> Stuck (* multiple matches; later info may decide *) +``` + +**Soundness:** we only commit when the match is unique. With multiple +matches we stay parked; if no further constraint disambiguates, the +final close-time check raises an "ambiguous TC instance" error +(distinguishable from "no instance" by carrying the candidate list). + +**Triggers:** none for now. The strategy is monotone — once a +candidate is excluded it stays excluded, since we only act when +`tc.args` are already ground. (Future: if `tc.args` start univar, +register `OnTcUniResolved` triggers.) + +**Risk surface:** +- A user's instance-DB shape can change ("which instances are visible") + via imports/cloning. The strategy must use whatever + `TcInstance.get_all` returns at the moment the strategy fires — + consistent with how current Mode #1 already works. +- Picking a non-canonical "exactly one" must be robust against import + order. `etyargs_match` is structural; we are safe. + +## Test matrix (Phase 3) + +``` +tests/tc/multi-param-bare-ops.ec + - bare op, unique instance → resolves + - two competing instances → "ambiguous TC instance" error + - args still univar at start, + resolved later by usage → eventually resolves (deferred) + - no matching instance → "no instance" error +``` + +Plus the existing `tests/tc/`, `theories/`, and `tests/` regression +sweeps to ensure single-parameter TC behaviour does not change. + +## Future work (Phase 4-5) + +- **Functional dependencies** in TC syntax: `class ('a, 'b) embed | 'a 'b -> embed` + declares the dependency explicitly. The By-args strategy is then + *justified by the declaration*, not by enumeration. Also enables + duplicate-instance detection at instance-binding time. + +- **Anticipated future rows in the catalog:** + - TC arg inference from operator bodies (axiom RHSs that mention TC ops). + - Inference through hypotheses introduced by `intros`. + - `Tglob` / module-type carriers. + - Coercion across same-named ops in different TCs. + +Each new gap follows the same recipe: add a row, add a test, add a +strategy, route diagnostics through the same `Failed` path. diff --git a/doc/typeclasses.md b/doc/typeclasses.md new file mode 100644 index 0000000000..7b15679c23 --- /dev/null +++ b/doc/typeclasses.md @@ -0,0 +1,328 @@ +# Typeclasses — current status + +Status snapshot of the typeclass implementation on the `deploy-tc` branch. +Every feature listed under "Implemented" is exercised by a test under +[`tests/tc/`](../tests/tc/); pointers given inline. + +--- + +## Implemented + +### 1. Declaration + +A typeclass declares a set of operators and axioms parameterised over a +single carrier type, optionally inheriting from a parent class: + +``` +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +type class group <: addmonoid = { + op opp : group -> group + axiom addmN : forall (x : group), opp x + x = idm +}. +``` + +- The carrier is referenced by the typeclass name itself inside the body + (`addmonoid`, `group`). +- Operators in the body are abstract; a concrete instance must realise + them. +- Axioms must have all their type/typeclass variables bound; underconstrained + axioms (`axiom foo : zero = zero`, where the carrier is left free) are + rejected with a clear `axiom 'foo' is type-ambiguous` message. + ([tests/tc/grandparent-op.ec](../tests/tc/grandparent-op.ec)) +- Inheritance is by `<:`. Multiple ancestors form a chain via `tc_prt`. +- See: [tests/tc/basic.ec](../tests/tc/basic.ec), + [tests/tc/inheritance.ec](../tests/tc/inheritance.ec). + +### 2. Multi-parameter typeclasses + +A typeclass may take leading type parameters in addition to the carrier: + +``` +type class ['a, 'b] embed = { + op proj : embed -> 'a + op inj : 'b -> embed + axiom dummy : true +}. +``` + +The carrier is still `embed`; `'a` and `'b` are auxiliary type parameters +of the class. +See: [tests/tc/multi-param.ec](../tests/tc/multi-param.ec). + +### 3. Instances + +An `instance` declaration realises a typeclass at a specific type: + +``` +op zero_int : int = 0. +op plus_int : int -> int -> int = Int.( + ). + +instance addmonoid as int_inst with int + op idm = zero_int + op (+) = plus_int. + +realize addmA by rewrite /plus_int; smt(). +realize addmC by rewrite /plus_int; smt(). +realize add0m by rewrite /plus_int /zero_int; smt(). +``` + +For a multi-parameter typeclass, the leading parameters are bound +positionally: + +``` +instance (int, bool) embed as pair_inst with (int * bool) + op proj = proj_pair + op inj = inj_pair. + +realize dummy by trivial. +``` + +- The instance name (`as int_inst`) is optional; an auto-generated name + is used otherwise. +- Multiple named instances for the same typeclass at different carrier + types coexist. + ([tests/tc/multi-instance.ec](../tests/tc/multi-instance.ec)) +- Each axiom must be discharged via `realize`. + +### 4. Polymorphic ops and lemmas over typeclasses + +``` +op double ['a <: addmonoid] (x : 'a) : 'a = x + x. + +lemma idm_idem ['a <: addmonoid] (x : 'a) : idm + x = x. +proof. by apply add0m. qed. +``` + +Operators and lemmas can be parameterised by a type variable constrained +by a typeclass; they are usable at any type with a matching instance. + +A type-parameter can also be constrained by a parametric typeclass that +references earlier type-parameters: + +``` +lemma round_trip + ['a, 'b, 'c <: ('a, 'b) embed] + (x : 'a) (y : 'b) : + proj<:'a, 'b, 'c> (inj<:'a, 'b, 'c> y) = x => + proj<:'a, 'b, 'c> (inj<:'a, 'b, 'c> y) = x. +proof. by apply proj_inj. qed. +``` + +### 5. Instantiation at use sites + +Explicit positional instantiation: + +``` +apply (idm_idem<:int> 5). +``` + +When a tparam is constrained by a typeclass and the user-supplied type +does not satisfy it, the diagnostic is clear: + +``` +type int does not satisfy typeclass constraint addmonoid +``` + +(Formerly produced a confusing "int doesn't match int" unification +diff.) +See: [tests/tc/explicit-tvi.ec](../tests/tc/explicit-tvi.ec). + +When the constraint references earlier tparams (`'c <: ('a, 'b) embed`), +the user-supplied bindings for `'a, 'b` are substituted before the +instance lookup, so a multi-parameter `apply +(round_trip<:int, bool, (int * bool)>)` works. +See: [tests/tc/multi-param.ec](../tests/tc/multi-param.ec). + +### 6. Sections + +The `declare type t <: tc.` form abstracts a TC-constrained carrier +inside a section. Operators and lemmas using `t` survive section close +as TC-polymorphic forms: + +``` +section. + declare type t <: addmonoid. + + op double (x : t) : t = x + x. + + lemma double_idm : double idm = idm. + proof. by rewrite /double add0m. qed. +end section. + +(* After close: *) +op test ['a <: addmonoid] (x : 'a) : 'a = double x. +``` + +See: [tests/tc/section.ec](../tests/tc/section.ec), +[tests/tc/declare-type.ec](../tests/tc/declare-type.ec). + +### 7. Cloning abstract theories + +An abstract theory parametrised by a TC-constrained carrier can be +cloned with a concrete instance carrier; the substitution threads +through TC witnesses, and the cloned operators reduce via the matching +instance: + +``` +abstract theory T. + type t <: addmonoid. + op double (x : t) : t = x + x. +end T. + +clone T as TI with type t = int. + +(* TI.double zero_int reduces to plus_int zero_int zero_int. *) +``` + +See: [tests/tc/clone-with-instance.ec](../tests/tc/clone-with-instance.ec), +[tests/tc/clone.ec](../tests/tc/clone.ec). + +### 8. Reduction (`delta_tc`) + +The reduction info exposes a `delta_tc` flag. When set, TC operators +applied at concrete (non-abstract) carriers reduce to the corresponding +instance body. When the witness was substituted to `\`Abs ` +(e.g. via theory cloning), the reducer infers the matching instance +on-the-fly. + +### 9. SMT integration + +When `smt()` (or `smt(...)`) is called over a goal whose context contains +type parameters constrained by typeclasses, every axiom of those +typeclasses (and their ancestors, deduplicated) is automatically added +to the Why3 task. This means `smt()` (no hints) closes goals over +abstract carriers that previously required `smt(addmA addmC add0m ...)`. + +For concrete carriers, the `delta_tc` pre-reduction in the SMT init +collapses TC operators to their instance bodies before translation. + +See: [tests/tc/smt.ec](../tests/tc/smt.ec). + +### 10. Diamond and multi-level inheritance + +``` +type class base = { ... } +type class tc1 <: base = { ... } +type class tc2 <: base = { ... } +type class tc3 <: tc1 = { ... } +``` + +The ancestor walk reaches `base` from `tc3` (lift = 2) without +duplication. SMT auto-axiom inclusion deduplicates by axiom path. + +See: [tests/tc/diamond.ec](../tests/tc/diamond.ec). + +### 11. Pretty-printing + +`type t.` prints as `type t.` for unconstrained abstract types and as +`type t <: addmonoid.` when constrained. Empty etyarg/witness brackets +are elided: `int[int_inst]` instead of `int[int_inst[]]`, +`addmonoid` instead of `addmonoid[]`. The `<:tc>` suffix on operators +appears only when the witness is a non-trivial reference (univar +placeholders, abstract carriers, parametric instances). + +--- + +## Known limitations + +### Polymorphic-body bare ops on parametric-carrier typeclasses + +Inside a polymorphic body — say a lemma `['a, 'b, 'c <: ('a, 'b) embed] +... proj (inj y) ...` — bare ops still need explicit tvi +(`proj<:'a, 'b, 'c>`). The carrier is a type parameter, not a concrete +type, so the By-args strategy (which picks an instance from the +database) does not fire. At ground call sites the carrier is inferred +automatically; see [tests/tc/multi-param-bare-ops.ec](../tests/tc/multi-param-bare-ops.ec) +and [doc/typeclasses-inference.md](typeclasses-inference.md). + +### Tuple/function carriers in instance declarations + +Parser-side, `instance ... with (int * bool)` is accepted; the +resulting carrier type does flow through. But the upstream "carrier" +typing path does not currently accept declaring an instance directly on +a Tuple or Tfun type unless wrapped — see the `assert false` in +`subst_tcw` ([src/ecSubst.ml:226](../src/ecSubst.ml#L226)) which is +guarded behind an upstream rejection. This is a latent issue if upstream +loosens. + +### Reverse-rewrite of bare-metavariable lemmas + +A pattern like `rewrite -{1 2 3}mulrr` where `mulrr : forall x, x*x = x` +picks the first (largest) successful unification of `x`, which often +yields fewer occurrences than the user expects. Workaround: explicit +arg, `rewrite -{1 2 3}(mulrr (x + x))`. This is a pre-existing +matcher behaviour, not TC-specific (reproduces on `main` without +typeclasses); fix would touch the rewrite engine more broadly. + +--- + +## Examples in `examples/tcstdlib/` and `examples/typeclasses/` + +- [TcMonoid.ec](../examples/tcstdlib/TcMonoid.ec) — compiles cleanly. +- [TcRing.ec](../examples/tcstdlib/TcRing.ec) — compiles cleanly. +- [examples/typeclasses/monoidtc.ec](../examples/typeclasses/monoidtc.ec) + and + [examples/typeclasses/typeclass.ec](../examples/typeclasses/typeclass.ec) + — compile cleanly. + +--- + +## Files of interest + +| Concern | File | +|-------------------------------|-------------------------------| +| AST: `tcwitness`, etyargs | [src/ecAst.ml](../src/ecAst.ml) | +| Typeclass declarations | [src/ecScope.ml `add_class`](../src/ecScope.ml) | +| Instance declarations | [src/ecScope.ml `add_instance`](../src/ecScope.ml) | +| TC inference / ancestor walk | [src/ecTypeClass.ml](../src/ecTypeClass.ml) | +| Unifier `\`TcCtt` resolution | [src/ecUnify.ml](../src/ecUnify.ml) | +| Section close | [src/ecSection.ml `generalize_*`](../src/ecSection.ml) | +| Theory clone replay | [src/ecTheoryReplay.ml](../src/ecTheoryReplay.ml) | +| Reduction (`delta_tc`) | [src/ecReduction.ml](../src/ecReduction.ml) | +| SMT auto-axiom inclusion | [src/ecSmt.ml `trans_tc_axioms`](../src/ecSmt.ml) | +| Pretty-printing | [src/ecPrinting.ml](../src/ecPrinting.ml) | +| Tvi diagnostic | [src/ecProofTyping.ml `pf_check_tvi`](../src/ecProofTyping.ml) | + +--- + +## Test suite + +Positive tests are under [`tests/tc/`](../tests/tc/) (scenario `unit`); +negative regression tests — files that must fail compilation with a +specific diagnostic — are under [`tests/tc-ko/`](../tests/tc-ko/) +(scenario `tc-ko`). + +| File | What it covers | +|----------------------------|-------------------------------------------------| +| `basic.ec` | Minimal class + instance + lemma | +| `clone.ec` | Cloning a theory containing a TC declaration | +| `clone-with-instance.ec` | Cloning an abstract theory with TC carrier | +| `declare-type.ec` | Section closure with `declare type t <: tc` | +| `diamond.ec` | Diamond inheritance + SMT auto-axioms | +| `explicit-tvi.ec` | Explicit `<:int>` and bare apply | +| `grandparent-op.ec` | Underconstrained-axiom diagnostic + workarounds | +| `inheritance.ec` | Two-level subclass chain | +| `instance.ec` | Multiple ops/axioms in an instance | +| `multi-instance.ec` | Two named instances for one TC at different types | +| `multi-param.ec` | `('a, 'b) embed` + polymorphic lemma + instance | +| `multi-param-bare-ops.ec` | Bare-op carrier inference for multi-param TCs | + +Negative tests under `tests/tc-ko/`: + +| File | Asserted error message | +|------------------------------|-------------------------------------------------| +| `bad-tvi.ec` | `type int does not satisfy typeclass constraint addmonoid` | +| `underconstrained-axiom.ec` | `axiom 'tc3_extra' is type-ambiguous: ...` | +| `ambiguous-instance.ec` | `ambiguous typeclass instance for embed; candidates: ...` | +| `parametric.ec` | Parametric TC `['a <: tc] action` | +| `print.ec` | `print` does not crash on TC entities | +| `section.ec` | Typeclass declared inside a section | +| `smt.ec` | SMT over abstract carriers (with/without hints) | diff --git a/dune-project b/dune-project index 85f142616e..ec8d76f29d 100644 --- a/dune-project +++ b/dune-project @@ -24,4 +24,5 @@ (why3 (and (>= 1.8.0) (< 1.9))) yojson (zarith (>= 1.10)) -)) + ) +) diff --git a/examples/tcalgebra/TcBigalg.ec b/examples/tcalgebra/TcBigalg.ec new file mode 100644 index 0000000000..70162669cb --- /dev/null +++ b/examples/tcalgebra/TcBigalg.ec @@ -0,0 +1,357 @@ +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import AllCore List StdOrder. +require import TcMonoid TcRing TcBigop. + +import IntOrder. + +(* ==================================================================== *) +(* Big sums over an additive group. Mirrors *) +(* [theories/algebra/Bigalg.ec:BigZModule] but as a TC section on *) +(* [addgroup] carriers. *) +(* ==================================================================== *) +section. +declare type t <: addgroup. + +(* -------------------------------------------------------------------- *) +lemma sumrD ['a] (P : 'a -> bool) (F1 F2 : 'a -> t) (r : 'a list) : + (big P F1 r) + (big P F2 r) = big P (fun x => F1 x + F2 x) r. +proof. by rewrite big_split. qed. + +(* -------------------------------------------------------------------- *) +lemma sumrN ['a] (P : 'a -> bool) (F : 'a -> t) (r : 'a list) : + - (big P F r) = big P (fun x => -(F x)) r. +proof. by apply/(big_endo oppr0 opprD). qed. + +(* -------------------------------------------------------------------- *) +lemma sumrB ['a] (P : 'a -> bool) (F1 F2 : 'a -> t) (r : 'a list) : + (big P F1 r) - (big P F2 r) = big P (fun x => F1 x - F2 x) r. +proof. by rewrite sumrN sumrD; apply/eq_bigr => /=. qed. + +(* -------------------------------------------------------------------- *) +lemma sumr_const ['a] (P : 'a -> bool) (x : t) (s : 'a list) : + big P (fun _ => x) s = intmul x (count P s). +proof. by rewrite big_const intmulpE 1:count_ge0 // -iteropE. qed. + +lemma sumri_const (k : t) (n m : int) : + n <= m => bigi predT (fun _ => k) n m = intmul k (m - n). +proof. by move=> h; rewrite sumr_const count_predT size_range /#. qed. + +(* -------------------------------------------------------------------- *) +lemma sumr_undup ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) : + big P F s = big P (fun a => intmul (F a) (count (pred1 a) s)) (undup s). +proof. +rewrite big_undup; apply/eq_bigr => x _ /=. +by rewrite intmulpE ?count_ge0 iteropE. +qed. + +(* -------------------------------------------------------------------- *) +lemma telescoping_sum (F : int -> t) (m n : int) : + m <= n => F m - F n = bigi predT (fun i => F i - F (i+1)) m n. +proof. +move=> /ler_eqVlt [<<- | hmn]. ++ by rewrite big_geq 1:// subrr. +rewrite -sumrB (@big_ltn m n F) 1:// /=. +have heq: n = n - 1 + 1 by ring. +rewrite heq (@big_int_recr (n-1) m) 1:/# -heq /=. +rewrite (@big_reindex _ _ (fun x => x - 1) (fun x => x + 1) (range m (n - 1))) //. +have ->: (transpose Int.(+) 1) = ((+) 1). ++ by apply: fun_ext => x; ring. +have ->: predT \o transpose Int.(+) (-1) = predT by done. +by rewrite /(\o) /= -(@range_addl m n 1) (@addrC _ (F n)) subr_add2r. +qed. + +lemma telescoping_sum_down (F : int -> t) (m n : int) : + m <= n => F n - F m = bigi predT (fun i => F (i+1) - F i) m n. +proof. +move=> hmn; have /= := telescoping_sum (fun i => -F i) _ _ hmn. +by rewrite opprK addrC => ->; apply eq_big => //= i _; rewrite opprK addrC. +qed. + +end section. + +(* ==================================================================== *) +(* Big sums over a [comring] carrier. Mirrors *) +(* [theories/algebra/Bigalg.ec:BigComRing.BAdd] (additive view). *) +(* ==================================================================== *) +section. +declare type t <: comring. + +(* -------------------------------------------------------------------- *) +lemma sumr_1 ['a] (P : 'a -> bool) (s : 'a list) : + bigA P (fun _ => oner<:t>) s = ofint (count P s). +proof. by apply/sumr_const. qed. + +(* -------------------------------------------------------------------- *) +lemma mulr_suml ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) (x : t) : + (bigA P F s) * x = bigA P (fun i => F i * x) s. +proof. by rewrite big_distrl //; (apply/mul0r || apply/mulrDl). qed. + +lemma mulr_sumr ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) (x : t) : + x * (bigA P F s) = bigA P (fun i => x * F i) s. +proof. by rewrite big_distrr //; (apply/mulr0 || apply/mulrDr). qed. + +lemma divr_suml ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) (x : t) : + (bigA P F s) / x = bigA P (fun i => F i / x) s. +proof. by rewrite mulr_suml; apply/eq_bigr. qed. + +(* -------------------------------------------------------------------- *) +lemma sum_pair_dep ['a 'b] (u : 'a -> t) (v : 'a -> 'b -> t) (J : ('a * 'b) list) : + uniq J => + bigA predT (fun (ij : 'a * 'b) => u ij.`1 * v ij.`1 ij.`2) J + = bigA predT + (fun i => u i * bigA predT + (fun ij : _ * _ => v ij.`1 ij.`2) + (filter (fun ij : _ * _ => ij.`1 = i) J)) + (undup (unzip1 J)). +proof. +move=> uqJ; rewrite big_pair // &(eq_bigr) => /= a _. +by rewrite mulr_sumr !big_filter &(eq_bigr) => -[a' b] /= ->>. +qed. + +lemma sum_pair ['a 'b] (u : 'a -> t) (v : 'b -> t) (J : ('a * 'b) list) : + uniq J => + bigA predT (fun (ij : 'a * 'b) => u ij.`1 * v ij.`2) J + = bigA predT + (fun i => u i * bigA predT v + (unzip2 (filter (fun ij : _ * _ => ij.`1 = i) J))) + (undup (unzip1 J)). +proof. +move=> uqJ; rewrite (@sum_pair_dep u (fun _ => v)) // &(eq_bigr) /=. +by move=> a _ /=; congr; rewrite big_map predT_comp /(\o). +qed. + +(* -------------------------------------------------------------------- *) +lemma mulr_big ['a 'b] + (P : 'a -> bool) (Q : 'b -> bool) (f : 'a -> t) (g : 'b -> t) + (r : 'a list) (s : 'b list) : + bigA P f r * bigA Q g s + = bigA P (fun x => bigA Q (fun y => f x * g y) s) r. +proof. +elim: r s => [|x r ih] s; first by rewrite big_nil mul0r. +rewrite !big_cons; case: (P x) => Px; last by rewrite ih. +by rewrite mulrDl -ih mulr_sumr. +qed. + +(* -------------------------------------------------------------------- *) +lemma mulr_const_cond ['a] p (s : 'a list) (c : t) : + bigM<:'a, t> p (fun _ => c) s = exp c (count p s). +proof. +rewrite big_const -iteropE /exp. +by rewrite IntOrder.ltrNge count_ge0. +qed. + +lemma mulr_const ['a] (s : 'a list) (c : t) : + bigM<:'a, t> predT (fun _ => c) s = exp c (size s). +proof. by rewrite mulr_const_cond count_predT. qed. + +(* -------------------------------------------------------------------- *) +lemma subrXX (x y : t) n : 0 <= n => + exp x n - exp y n = (x - y) * (bigiA predT (fun i => exp x (n - 1 - i) * exp y i) 0 n). +proof. +case: n => [|n ge0_n _]; first by rewrite !expr0 big_geq // subrr mulr0. +rewrite mulrBl !(big_distrr mulr0 mulrDr). +rewrite big_int_recl // big_int_recr //= !expr0 /=. +rewrite !(mulr1, mul1r) -!exprS // opprD !addrA; congr. +rewrite -addrA sumrB /= big_seq big1 ?addr0 //=. +move=> i /mem_range rg_i; rewrite mulrA -exprS 1:/# mulrCA. +by rewrite -exprS 1:/# subr_eq0; do 2! congr => /#. +qed. + +end section. + +(* ==================================================================== *) +(* Big sums / products under an ordered domain. Mirrors *) +(* [theories/algebra/Bigalg.ec:BigOrder]. *) +(* ==================================================================== *) +require import TcNumber. + +section. +declare type t <: tcrealdomain. + +lemma ler_sum ['a] (P : 'a -> bool) (F1 F2 : 'a -> t) s : + (forall a, P a => F1 a <= F2 a) + => (bigA P F1 s <= bigA P F2 s). +proof. +apply: (@big_ind2 (fun (x y : t) => x <= y)) => //=. + by apply/ler_add. +qed. + +lemma sumr_ge0 ['a] (P : 'a -> bool) (F : 'a -> t) s : + (forall a, P a => zero <= F a) + => zero <= bigA P F s. +proof. +move=> h; apply: (@big_ind (fun (x : t) => zero <= x)) => //=. + by apply/addr_ge0. +qed. + +lemma sub_ler_sum ['a] (P1 P2 : 'a -> bool) (F1 F2 : 'a -> t) s : + (forall x, P1 x => P2 x) => + (forall x, P1 x => F1 x <= F2 x) => + (forall x, P2 x => !P1 x => zero <= F2 x) => + bigA P1 F1 s <= bigA P2 F2 s. +proof. +move => sub_P1_P2 le_F1_F2 pos_F2; rewrite (@bigID P2 _ P1). +have -> : predI P2 P1 = P1 by smt(). +by rewrite -(addr0 (bigA P1 F1 s)) ler_add ?ler_sum // sumr_ge0 /#. +qed. + +lemma sumr_norm ['a] P (F : 'a -> t) s : + (forall x, P x => zero <= F x) => + bigA P (fun x => `|F x|) s = bigA P F s. +proof. +by move=> ge0_F; apply: eq_bigr => /= a Pa; rewrite ger0_norm /#. +qed. + +lemma ler_sum_seq ['a] (P : 'a -> bool) (F1 F2 : 'a -> t) s : + (forall a, mem s a => P a => F1 a <= F2 a) + => (bigA P F1 s <= bigA P F2 s). +proof. +move=> h; rewrite !(@big_seq_cond P). +by rewrite ler_sum=> //= x []; apply/h. +qed. + +lemma sumr_ge0_seq ['a] (P : 'a -> bool) (F : 'a -> t) s : + (forall a, mem s a => P a => zero <= F a) + => zero <= bigA P F s. +proof. +move=> h; rewrite !(@big_seq_cond P). +by rewrite sumr_ge0=> //= x []; apply/h. +qed. + +lemma prodr_ge0 ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) : + (forall a, P a => zero <= F a) + => zero <= bigM P F s. +proof. +move=> h; apply: (@big_ind (fun (x : t) => zero <= x)) => //=. + by apply/mulr_ge0. +qed. + +lemma prodr_gt0 ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) : + (forall a, P a => zero < F a) + => zero < bigM P F s. +proof. +move=> h; apply: (@big_ind (fun (x : t) => zero < x)) => //=. + by apply/mulr_gt0. +qed. + +lemma ler_prod ['a] (P : 'a -> bool) (F1 F2 : 'a -> t) s : + (forall a, P a => zero <= F1 a <= F2 a) + => (bigM P F1 s <= bigM P F2 s). +proof. +move=> h; elim: s => [|x s ih]; first by rewrite !big_nil lerr. +rewrite !big_cons; case: (P x)=> // /h [ge0F1x leF12x]. +by apply/ler_pmul=> //; apply/prodr_ge0=> a /h []. +qed. + +lemma prodr_ge0_seq ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) : + (forall a, mem s a => P a => zero <= F a) + => zero <= bigM P F s. +proof. +move=> h; rewrite !(@big_seq_cond P). +by rewrite prodr_ge0=> //= x []; apply/h. +qed. + +lemma prodr_gt0_seq ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) : + (forall a, mem s a => P a => zero < F a) + => zero < bigM P F s. +proof. +move=> h; rewrite !(@big_seq_cond P). +by rewrite prodr_gt0=> //= x []; apply/h. +qed. + +lemma ler_prod_seq ['a] (P : 'a -> bool) (F1 F2 : 'a -> t) s : + (forall a, mem s a => P a => zero <= F1 a <= F2 a) + => (bigM P F1 s <= bigM P F2 s). +proof. +move=> h; rewrite !(@big_seq_cond P). +by rewrite ler_prod=> //= x []; apply/h. +qed. + +lemma big_normr ['a] P (F : 'a -> t) s : + `|bigA P F s| <= bigA P (fun x => `|F x|) s. +proof. +elim: s => [|x s ih]; first by rewrite !big_nil normr0. +rewrite !big_cons /=; case: (P x) => // Px. +have /ler_trans := ler_norm_add (F x) (bigA P F s); apply. +by rewrite ler_add2l. +qed. + +lemma gt0_prodr_seq ['a] (P : 'a -> bool) (F : 'a -> t) (s : 'a list) : + (forall (a : 'a), a \in s => P a => zero <= F a) => + zero < bigM P F s => + (forall (a : 'a), a \in s => P a => zero < F a). +proof. +elim: s => // x s IHs F_ge0; rewrite big_cons. +have {IHs} IHs := IHs _; first by smt(). +case: (P x) => [Px F_big_gt0 a a_x_s Pa| nPx /IHs]; 2:smt(). +smt(pmulr_gt0 prodr_ge0_seq). +qed. + +lemma prodr_eq0 ['a] P (F : 'a -> t) s : + (exists x, P x /\ x \in s /\ F x = zero) + <=> bigM<:'a, t> P F s = zero. +proof. split. ++ case=> x [# Px x_in_s z_Fx]; rewrite (@big_rem _ _ _ x) //. + by rewrite Px /= z_Fx mul0r. ++ elim: s => [|x s ih] /=; 1: by rewrite big_nil oner_neq0. + rewrite big_cons /=; case: (P x) => Px; last first. + - by move/ih; case=> y [# Py ys z_Fy]; exists y; rewrite Py ys z_Fy. + rewrite mulf_eq0; case=> [z_Fx|]; first by exists x. + by move/ih; case=> y [# Py ys z_Fy]; exists y; rewrite Py ys z_Fy. +qed. + +lemma ler_pexpn2r n (x y : t) : + 0 < n => zero <= x => zero <= y => (exp x n <= exp y n) <=> (x <= y). +proof. +move=> gt0_n ge0_x ge0_y; split => [|h]; last first. +- by apply/ler_pexp=> //; apply/ltzW. +case: (x = zero) => [->>|nz_x]. +- by rewrite expr0n 1:ltzW. +rewrite -subr_ge0 subrXX 1:ltzW // pmulr_lge0 ?subr_ge0 //=. +rewrite {2}(_ : n = n - 1 + 1) 1:#ring big_int_recr /= 1:/#. +rewrite expr0 /= ltr_spaddr ?mul1r; 1: by rewrite expr_gt0 ltr_neqAle /#. +by rewrite sumr_ge0 => /= i _; rewrite mulr_ge0 ?expr_ge0. +qed. + +lemma sum_expr (p : t) n : 0 <= n => + (oner - p) * bigiA predT (fun i => exp p i) 0 n = oner - exp p n. +proof. +move=> hn; have /eq_sym := subrXX oner p n hn. +rewrite expr1z // => <-; congr. +by apply: eq_big_int => i _ /=; rewrite expr1z mul1r. +qed. + +lemma sum_expr_le (p : t) n : + 0 <= n + => zero <= p < oner + => (oner - p) * bigiA predT (fun i => exp p i) 0 n <= oner. +proof. +move=> ge0_n [ge0_p lt1_p]; rewrite sum_expr //. +by rewrite ler_subl_addr ler_paddr // expr_ge0. +qed. + +lemma sum_iexpr_le (p : t) n : zero <= p < oner => + exp (oner - p) 2 * bigiA predT (fun i => ofint i * exp p i) 0 n <= oner. +proof. +case=> [ge0_p lt1_p]; elim/natind: n => [n le0_n|n ge0_n ih]. ++ by rewrite big_geq // mulr0. +rewrite big_ltn 1:/# /= ofint0 mul0r add0r. +pose F := fun j => exp p j + p * ((ofint<:t> j - oner) * exp p (j - 1)). +rewrite (@eq_big_int _ _ _ F) => /= [i [gt0_i lti]|]. +- by rewrite /F mulrCA -expr_pred 1:/# mulrBl mul1r addrC subrK. +rewrite -sumrD -mulr_sumr mulrDr. +apply: (ler_trans ((oner - p) + p)); last by rewrite lerr_eq subrK. +apply: ler_add. +- rewrite expr2 -mulrA ler_pimulr 1:subr_ge0 1:ltrW //. + have le := sum_expr_le p (n+1) _ _ => //; first move=> /#. + rewrite &(ler_trans _ _ le) ler_wpmul2l 1:subr_ge0 1:ltrW //. + by rewrite (@big_ltn 0) 1:/# /= expr0 ler_paddl. +rewrite mulrCA ler_pimulr // &(ler_trans _ _ ih). +rewrite ler_wpmul2l; first by rewrite expr_ge0 subr_ge0 ltrW. +rewrite &(lerr_eq) (@big_addn 0 _ 1) &(eq_big_int) /=. +by move=> i [ge0_i _]; rewrite ofintS // addrAC subrr add0r. +qed. + +end section. diff --git a/examples/tcalgebra/TcBigop.ec b/examples/tcalgebra/TcBigop.ec new file mode 100644 index 0000000000..02d3fa6cfe --- /dev/null +++ b/examples/tcalgebra/TcBigop.ec @@ -0,0 +1,602 @@ +(* This API has been mostly inspired from the [bigop] library of the + * ssreflect Coq extension. *) + +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import AllCore List Ring TcMonoid. + +import Ring.IntID. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: monoid. + +(* -------------------------------------------------------------------- *) +op big (P : 'a -> bool) (F : 'a -> t) (r : 'a list) = + foldr (+) idm (map F (filter P r)). + +(* -------------------------------------------------------------------- *) +abbrev bigi (P : int -> bool) (F : int -> t) i j = + big P F (range i j). + +(* -------------------------------------------------------------------- *) +lemma big_nil (P : 'a -> bool) (F : 'a -> t): big P F [] = idm. +proof. by []. qed. + +(* -------------------------------------------------------------------- *) +lemma big_cons (P : 'a -> bool) (F : 'a -> t) x s: + big P F (x :: s) = if P x then F x + big P F s else big P F s. +proof. by rewrite {1}/big /= (@fun_if (map F)); case (P x). qed. + +lemma big_consT (F : 'a -> t) x s: + big predT F (x :: s) = F x + big predT F s. +proof. by apply/big_cons. qed. + +(* -------------------------------------------------------------------- *) +lemma big_rec (K : t -> bool) r P (F : 'a -> t): + K idm => (forall i x, P i => K x => K (F i + x)) => K (big P F r). +proof. + move=> K0 Kop; elim: r => //= i r; rewrite big_cons. + by case (P i) => //=; apply/Kop. +qed. + +lemma big_ind (K : t -> bool) r P (F : 'a -> t): + (forall x y, K x => K y => K (x + y)) + => K idm => (forall i, P i => K (F i)) + => K (big P F r). +proof. + move=> Kop Kidx K_F; apply/big_rec => //. + by move=> i x Pi Kx; apply/Kop => //; apply/K_F. +qed. + +lemma big_rec2: + forall (K : t -> t -> bool) r P (F1 F2 : 'a -> t), + K idm idm + => (forall i y1 y2, P i => K y1 y2 => K (F1 i + y1) (F2 i + y2)) + => K (big P F1 r) (big P F2 r). +proof. + move=> K r P F1 F2 KI KF; elim: r => //= i r IHr. + by rewrite !big_cons; case (P i) => ? //=; apply/KF. +qed. + +lemma big_ind2: + forall (K : t -> t -> bool) r P (F1 F2 : 'a -> t), + (forall x1 x2 y1 y2, K x1 x2 => K y1 y2 => K (x1 + y1) (x2 + y2)) + => K idm idm + => (forall i, P i => K (F1 i) (F2 i)) + => K (big P F1 r) (big P F2 r). +proof. + move=> K r P F1 F2 Kop KI KF; apply/big_rec2 => //. + by move=> i x1 x2 Pi Kx1x2; apply/Kop => //; apply/KF. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_endo (f : t -> t): + f idm = idm + => (forall (x y : t), f (x + y) = f x + f y) + => forall r P (F : 'a -> t), + f (big P F r) = big P (f \o F) r. +proof. + (* FIXME: should be a consequence of big_morph *) + move=> fI fM; elim=> //= i r IHr P F; rewrite !big_cons. + by case (P i) => //=; rewrite 1?fM IHr. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_map ['a 'b] (h : 'b -> 'a) (P : 'a -> bool) F s: + big P F (map h s) = big (P \o h) (F \o h) s. +proof. by elim: s => // x s; rewrite map_cons !big_cons=> ->. qed. + +lemma big_mapT ['a 'b] (h : 'b -> 'a) F s: (* -> big_map_predT *) + big predT F (map h s) = big predT (F \o h) s. +proof. by rewrite big_map. qed. + +(* -------------------------------------------------------------------- *) +lemma big_comp ['a] (h : t -> t) (P : 'a -> bool) F s: + h idm = idm => morphism_2 h (+) (+) => + h (big P F s) = big P (h \o F) s. +proof. + move=> Hidm Hh;elim: s => // x s; rewrite !big_cons => <-. + by rewrite /(\o) -Hh;case (P x) => //. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_nth x0 (P : 'a -> bool) (F : 'a -> t) s: + big P F s = bigi (P \o (nth x0 s)) (F \o (nth x0 s)) 0 (size s). +proof. by rewrite -{1}(@mkseq_nth x0 s) /mkseq big_map. qed. + +(* -------------------------------------------------------------------- *) +lemma big_const (P : 'a -> bool) x s: + big P (fun i => x) s = iter (count P s) ((+) x) idm. +proof. + elim: s=> [|y s ih]; [by rewrite iter0 | rewrite big_cons /=]. + by rewrite ih; case (P y) => //; rewrite addzC iterS // count_ge0. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_seq1 (F : 'a -> t) x: big predT F [x] = F x. +proof. by rewrite big_cons big_nil addm0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_mkcond (P : 'a -> bool) (F : 'a -> t) s: + big P F s = big predT (fun i => if P i then F i else idm) s. +proof. + elim: s=> // x s ih; rewrite !big_cons -ih /predT /=. + by case (P x)=> //; rewrite add0m. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_filter (P : 'a -> bool) F s: + big predT F (filter P s) = big P F s. +proof. by elim: s => //= x s; case (P x)=> //; rewrite !big_cons=> -> ->. qed. + +(* -------------------------------------------------------------------- *) +lemma big_filter_cond (P1 P2 : 'a -> bool) F s: + big P2 F (filter P1 s) = big (predI P1 P2) F s. +proof. by rewrite -big_filter -(@big_filter _ _ s) predIC filter_predI. qed. + +(* -------------------------------------------------------------------- *) +lemma eq_bigl (P1 P2 : 'a -> bool) (F : 'a -> t) s: + (forall i, P1 i <=> P2 i) + => big P1 F s = big P2 F s. +proof. by move=> h; rewrite /big (eq_filter h). qed. + +(* -------------------------------------------------------------------- *) +lemma eq_bigr (P : 'a -> bool) (F1 F2 : 'a -> t) s: + (forall i, P i => F1 i = F2 i) + => big P F1 s = big P F2 s. +proof. (* FIXME: big_rec2 *) + move=> eqF; elim: s=> // x s; rewrite !big_cons=> <-. + by case (P x)=> // /eqF <-. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_distrl ['a] (op_ : t -> t -> t) (P : 'a -> bool) F s u: + left_zero idm op_ + => left_distributive op_ (+) + => op_ (big P F s) u = big P (fun a => op_ (F a) u) s. +proof. + move=> mulm1 mulmDl; pose G := fun x => op_ x u. + move: (big_comp G P) => @/G /= -> //. + by rewrite mulm1. by move=> t1 t2; rewrite mulmDl. +qed. + +lemma big_distrr ['a] (op_ : t -> t -> t) (P : 'a -> bool) F s u: + right_zero idm op_ + => right_distributive op_ (+) + => op_ u (big P F s) = big P (fun a => op_ u (F a)) s. +proof. + move=> mul1m mulmDr; pose G := fun x => op_ u x. + move: (big_comp G P) => @/G /= -> //. + by rewrite mul1m. by move=> t1 t2; rewrite mulmDr. +qed. + +lemma big_distr ['a 'b] (op_ : t -> t -> t) + (P1 : 'a -> bool) (P2 : 'b -> bool) F1 s1 F2 s2 : + commutative op_ + => left_zero idm op_ + => left_distributive op_ (+) + => op_ (big P1 F1 s1) (big P2 F2 s2) = + big P1 (fun a1 => big P2 (fun a2 => op_ (F1 a1) (F2 a2)) s2) s1. +proof. + move=> mulmC mulm1 mulmDl; rewrite big_distrl //. + apply/eq_bigr=> i _ /=; rewrite big_distrr //. + by move=> x; rewrite mulmC mulm1. + by move=> x y z; rewrite !(mulmC x) mulmDl. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_andbC (P Q : 'a -> bool) (F : 'a -> t) s: + big (fun x => P x /\ Q x) F s = big (fun x => Q x /\ P x) F s. +proof. by apply/eq_bigl=> i. qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big (P1 P2 : 'a -> bool) (F1 F2 : 'a -> t) s: + (forall i, P1 i <=> P2 i) + => (forall i, P1 i => F1 i = F2 i) + => big P1 F1 s = big P2 F2 s. +proof. by move=> /eq_bigl <- /eq_bigr <-. qed. + +(* -------------------------------------------------------------------- *) +lemma congr_big r1 r2 P1 P2 (F1 F2 : 'a -> t): + r1 = r2 + => (forall x, P1 x <=> P2 x) + => (forall i, P1 i => F1 i = F2 i) + => big P1 F1 r1 = big P2 F2 r2. +proof. by move=> <-; apply/eq_big. qed. + +(* -------------------------------------------------------------------- *) +lemma big_hasC (P : 'a -> bool) (F : 'a -> t) s: !has P s => + big P F s = idm. +proof. + rewrite -big_filter has_count -size_filter. + by rewrite ltz_def size_ge0 /= => /size_eq0 ->. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_pred0_eq (F : 'a -> t) s: big pred0 F s = idm. +proof. by rewrite big_hasC // has_pred0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_pred0 (P : 'a -> bool) (F : 'a -> t) s: + (forall i, P i <=> false) => big P F s = idm. +proof. by move=> h; rewrite -(@big_pred0_eq F s); apply/eq_bigl. qed. + +(* -------------------------------------------------------------------- *) +lemma big_cat (P : 'a -> bool) (F : 'a -> t) s1 s2: + big P F (s1 ++ s2) = big P F s1 + big P F s2. +proof. + rewrite !(@big_mkcond P); elim: s1 => /= [|i s1 ih]. + by rewrite (@big_nil P F) add0m. + by rewrite !big_cons /(predT i) /= ih addmA. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_catl (P : 'a -> bool) (F : 'a -> t) s1 s2: !has P s2 => + big P F (s1 ++ s2) = big P F s1. +proof. by rewrite big_cat => /big_hasC ->; rewrite addm0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_catr (P : 'a -> bool) (F : 'a -> t) s1 s2: !has P s1 => + big P F (s1 ++ s2) = big P F s2. +proof. by rewrite big_cat => /big_hasC ->; rewrite add0m. qed. + +(* -------------------------------------------------------------------- *) +lemma big_rcons (P : 'a -> bool) (F : 'a -> t) s x: + big P F (rcons s x) = if P x then big P F s + F x else big P F s. +proof. + by rewrite -cats1 big_cat big_cons big_nil; case: (P x); rewrite !addm0. +qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_perm (P : 'a -> bool) (F : 'a -> t) s1 s2: + perm_eq s1 s2 => big P F s1 = big P F s2. +proof. + move=> /perm_eqP; rewrite !(@big_mkcond P). + elim s1 s2 => [|i s1 ih1] s2 eq_s12. + + case: s2 eq_s12=> // i s2 h. + by have := h (pred1 i)=> //=; smt(count_ge0). + have r2i: mem s2 i by rewrite -has_pred1 has_count -eq_s12 #smt:(count_ge0). + have/splitPr [s3 s4] ->> := r2i. + rewrite big_cat !big_cons /(predT i) /=. + rewrite addmCA; congr; rewrite -big_cat; apply/ih1=> a. + by have := eq_s12 a; rewrite !count_cat /= addzCA => /addzI. +qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_perm_map (F : 'a -> t) s1 s2: + perm_eq (map F s1) (map F s2) => big predT F s1 = big predT F s2. +proof. +by move=> peq; rewrite -!(@big_map F predT idfun) &(eq_big_perm). +qed. + +(* -------------------------------------------------------------------- *) +lemma big_seq_cond (P : 'a -> bool) (F : 'a -> t) s: + big P F s = big (fun i => mem s i /\ P i) F s. +proof. by rewrite -!(@big_filter _ _ s); congr; apply/eq_in_filter. qed. + +(* -------------------------------------------------------------------- *) +lemma big_seq (F : 'a -> t) s: + big predT F s = big (fun i => mem s i) F s. +proof. by rewrite big_seq_cond; apply/eq_bigl. qed. + +(* -------------------------------------------------------------------- *) +lemma big_rem (P : 'a -> bool) (F : 'a -> t) s x: mem s x => + big P F s = (if P x then F x else idm) + big P F (rem x s). +proof. + by move/perm_to_rem/eq_big_perm=> ->; rewrite !(@big_mkcond P) big_cons. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigD1 (F : 'a -> t) s x: mem s x => uniq s => + big predT F s = F x + big (predC1 x) F s. +proof. by move=> /big_rem-> /rem_filter->; rewrite big_filter. qed. + +(* -------------------------------------------------------------------- *) +lemma bigD1_cond P (F : 'a -> t) s x: P x => mem s x => uniq s => + big P F s = F x + big (predI P (predC1 x)) F s. +proof. +move=> Px sx uqs; rewrite -big_filter (@bigD1 _ _ x) ?big_filter_cond //. + by rewrite mem_filter Px. by rewrite filter_uniq. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigD1_cond_if P (F : 'a -> t) s x: uniq s => big P F s = + (if mem s x /\ P x then F x else idm) + big (predI P (predC1 x)) F s. +proof. +case: (mem s x /\ P x) => [[Px sx]|Nsx]; rewrite ?add0m /=. + by apply/bigD1_cond. +move=> uqs; rewrite big_seq_cond eq_sym big_seq_cond; apply/eq_bigl=> i /=. +by case: (i = x) => @/predC1 @/predI [->>|]. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_split (P : 'a -> bool) (F1 F2 : 'a -> t) s: + big P (fun i => F1 i + F2 i) s = big P F1 s + big P F2 s. +proof. + elim: s=> /= [|x s ih]; 1: by rewrite !big_nil addm0. + rewrite !big_cons ih; case: (P x) => // _. + by rewrite addmCA -!addmA addmCA. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigID (P : 'a -> bool) (F : 'a -> t) (a : 'a -> bool) s: + big P F s = big (predI P a) F s + big (predI P (predC a)) F s. +proof. +rewrite !(@big_mkcond _ F) -big_split; apply/eq_bigr => i _ /=. +by rewrite /predI /predC; case: (a i); rewrite ?addm0 ?add0m. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigU ['a] (P Q : 'a -> bool) (F : 'a -> t) s : (forall x, !(P x /\ Q x)) => + big (predU P Q) F s = big P F s + big Q F s. +proof. +move=> dj_PQ; rewrite (@bigID (predU _ _) _ P). +by congr; apply: eq_bigl => /#. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigEM ['a] (P : 'a -> bool) (F : 'a -> t) s : + big predT F s = big P F s + big (predC P) F s. +proof. by rewrite -bigU 1:/#; apply: eq_bigl => /#. qed. + +(* -------------------------------------------------------------------- *) +lemma big_reindex ['a 'b] + (P : 'a -> bool) (F : 'a -> t) (f : 'b -> 'a) (f' : 'a -> 'b) (s : 'a list) : + (forall x, x \in s => f (f' x) = x) + => big P F s = big (P \o f) (F \o f) (map f' s). +proof. +by move => /eq_in_map id_ff'; rewrite -big_map -map_comp id_ff' id_map. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_pair_pswap ['a 'b] (p : 'a * 'b -> bool) (f : 'a * 'b -> t) s : + big<:'a * 'b> p f s + = big<:'b * 'a> (p \o pswap) (f \o pswap) (map pswap s). +proof. by apply/big_reindex; case. qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_seq (F1 F2 : 'a -> t) s: + (forall x, mem s x => F1 x = F2 x) + => big predT F1 s = big predT F2 s. +proof. by move=> eqF; rewrite !big_seq; apply/eq_bigr. qed. + +(* -------------------------------------------------------------------- *) +lemma congr_big_seq (P1 P2: 'a -> bool) (F1 F2 : 'a -> t) s: + (forall x, mem s x => P1 x = P2 x) => + (forall x, mem s x => P1 x => P2 x => F1 x = F2 x) + => big P1 F1 s = big P2 F2 s. +proof. + move=> eqP eqH; rewrite big_mkcond eq_sym big_mkcond eq_sym. + apply/eq_big_seq=> x x_in_s /=; rewrite eqP //. + by case (P2 x)=> // P2x; rewrite eqH // eqP. +qed. + +(* -------------------------------------------------------------------- *) +lemma big1_eq (P : 'a -> bool) s: big P (fun (x : 'a) => idm) s = idm. +proof. + rewrite big_const; elim/natind: (count _ _)=> n. + by move/iter0<:t> => ->. + by move/iterS<:t> => -> ->; rewrite addm0. +qed. + +(* -------------------------------------------------------------------- *) +lemma big1 (P : 'a -> bool) (F : 'a -> t) s: + (forall i, P i => F i = idm) => big P F s = idm. +proof. by move/eq_bigr=> ->; apply/big1_eq. qed. + +(* -------------------------------------------------------------------- *) +lemma big1_seq (P : 'a -> bool) (F : 'a -> t) s: + (forall i, P i /\ (mem s i) => F i = idm) => big P F s = idm. +proof. by move=> eqF1; rewrite big_seq_cond big_andbC big1. qed. + +(* -------------------------------------------------------------------- *) +lemma big_eq_idm_filter ['a] (P : 'a -> bool) (F : 'a -> t) s : + (forall (x : 'a), !P x => F x = idm) => big predT F s = big P F s. +proof. +by move=> eq1; rewrite (@bigEM P) (@big1 (predC _)) // addm0. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_flatten (P : 'a -> bool) (F : 'a -> t) rr : + big P F (flatten rr) = big predT (fun s => big P F s) rr. +proof. +elim: rr => /= [|r rr ih]; first by rewrite !big_nil. +by rewrite flatten_cons big_cat big_cons -ih. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_pair ['a 'b] (F : 'a * 'b -> t) (s : ('a * 'b) list) : uniq s => + big predT F s = + big predT (fun a => + big predT F (filter (fun xy : _ * _ => xy.`1 = a) s)) + (undup (map fst s)). +proof. +move=> /perm_eq_pair /eq_big_perm /(_ predT F) ->. +by rewrite big_flatten big_map. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_nseq_cond (P : 'a -> bool) (F : 'a -> t) n x : + big P F (nseq n x) = if P x then iter n ((+) (F x)) idm else idm. +proof. +elim/natind: n => [n le0_n|n ge0_n ih]; first by rewrite ?(nseq0_le, iter0). +by rewrite nseqS // big_cons ih; case: (P x) => //; rewrite iterS. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_nseq (F : 'a -> t) n x : + big predT F (nseq n x) = iter n ((+) (F x)) idm. +proof. by apply/big_nseq_cond. qed. + +(* -------------------------------------------------------------------- *) +lemma big_undup ['a] (P : 'a -> bool) (F : 'a -> t) s : + big P F s = big P (fun a => iter (count (pred1 a) s) ((+) (F a)) idm) (undup s). +proof. +have <- := eq_big_perm P F _ _ (perm_undup_count s). +rewrite big_flatten big_map (@big_mkcond P); apply/eq_big => //=. +by move=> @/(\o) /= x _; apply/big_nseq_cond. +qed. + +(* -------------------------------------------------------------------- *) +lemma exchange_big (P1 : 'a -> bool) (P2 : 'b -> bool) (F : 'a -> 'b -> t) s1 s2: + big P1 (fun a => big P2 (F a) s2) s1 = + big P2 (fun b => big P1 (fun a => F a b) s1) s2. +proof. + elim: s1 s2 => [|a s1 ih] s2; first by rewrite big_nil big1_eq. + rewrite big_cons ih; case: (P1 a)=> h; rewrite -?big_split; + by apply/eq_bigr=> x _ /=; rewrite big_cons h. +qed. + +(* -------------------------------------------------------------------- *) +lemma partition_big ['a 'b] (px : 'a -> 'b) P Q (F : 'a -> t) s s' : + uniq s' + => (forall x, mem s x => P x => mem s' (px x) /\ Q (px x)) + => big P F s = big Q (fun x => big (fun y => P y /\ px y = x) F s) s'. +proof. +move=> uq_s'; elim: s => /~= [|x xs ih] hm. + by rewrite big_nil big1_eq. +rewrite big_cons; case: (P x) => /= [Px|PxN]; last first. + rewrite ih //; 1: by move=> y y_xs; apply/hm; rewrite y_xs. + by apply/eq_bigr=> i _ /=; rewrite big_cons /= PxN. +have := hm x; rewrite Px /= => -[s'_px Qpx]; apply/eq_sym. +rewrite (@bigD1_cond _ _ _ (px x)) //= big_cons /= Px /=. +rewrite -addmA; congr; apply/eq_sym; rewrite ih. + by move=> y y_xs; apply/hm; rewrite y_xs. +rewrite (@bigD1_cond _ _ _ (px x)) //=; congr. +apply/eq_bigr=> /= i [Qi @/predC1]; rewrite eq_sym => ne_pxi. +by rewrite big_cons /= ne_pxi. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_allpairs (f : 'a -> 'b -> 'c) (F : 'c -> t) s u: + big predT F (allpairs<:'a, 'b, 'c> f s u) + = big predT (fun x => big predT (fun y => F (f x y)) u) s. +proof. +elim: s u => [|x s ih] u //=. +by rewrite allpairs_consl big_cat ih big_consT big_map. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_cond m n P (F : int -> t): + bigi P F m n = bigi (fun i => m <= i < n /\ P i) F m n. +proof. by rewrite big_seq_cond; apply/eq_bigl=> i /=; rewrite mem_range. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int m n (F : int -> t): + bigi predT F m n = bigi (fun i => m <= i < n) F m n. +proof. by rewrite big_int_cond. qed. + +(* -------------------------------------------------------------------- *) +lemma congr_big_int (m1 n1 m2 n2 : int) P1 P2 (F1 F2 : int -> t): + m1 = m2 => n1 = n2 + => (forall i, m1 <= i < n2 => P1 i = P2 i) + => (forall i, P1 i /\ (m1 <= i < n2) => F1 i = F2 i) + => bigi P1 F1 m1 n1 = bigi P2 F2 m2 n2. +proof. + move=> <- <- eqP12 eqF12; rewrite big_seq_cond (@big_seq_cond P2). + by apply/eq_big=> i /=; rewrite mem_range #smt:(). +qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_int (m n : int) (F1 F2 : int -> t): + (forall i, m <= i < n => F1 i = F2 i) + => bigi predT F1 m n = bigi predT F2 m n. +proof. by move=> eqF; apply/congr_big_int. qed. + +(* -------------------------------------------------------------------- *) +lemma big_ltn_cond (m n : int) P (F : int -> t): m < n => + let x = bigi P F (m+1) n in + bigi P F m n = if P m then F m + x else x. +proof. by move/range_ltn=> ->; rewrite big_cons. qed. + +(* -------------------------------------------------------------------- *) +lemma big_ltn (m n : int) (F : int -> t): m < n => + bigi predT F m n = F m + bigi predT F (m+1) n. +proof. by move/big_ltn_cond=> /= ->. qed. + +(* -------------------------------------------------------------------- *) +lemma big_geq (m n : int) P (F : int -> t): n <= m => + bigi P F m n = idm. +proof. by move/range_geq=> ->; rewrite big_nil. qed. + +(* -------------------------------------------------------------------- *) +lemma big_addn (m n a : int) P (F : int -> t): + bigi P F (m+a) n + = bigi (fun i => P (i+a)) (fun i => F (i+a)) m (n-a). +proof. +rewrite range_addl big_map; apply/eq_big. + by move=> i /=; rewrite /(\o) addzC. +by move=> i /= _; rewrite /(\o) addzC. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_int1 n (F : int -> t): bigi predT F n (n+1) = F n. +proof. by rewrite big_ltn 1:/# big_geq // addm0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_cat_int (n m p : int) P (F : int -> t): m <= n => n <= p => + bigi P F m p = (bigi P F m n) + (bigi P F n p). +proof. by move=> lemn lenp; rewrite -big_cat -range_cat. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recl (n m : int) (F : int -> t): m <= n => + bigi predT F m (n+1) = F m + bigi predT (fun i => F (i+1)) m n. +proof. by move=> lemn; rewrite big_ltn 1?big_addn /= 1:/#. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recr (n m : int) (F : int -> t): m <= n => + bigi predT F m (n+1) = bigi predT F m n + F n. +proof. by move=> lemn; rewrite (@big_cat_int n) ?big_int1 //#. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recl_cond (n m : int) P (F : int -> t): m <= n => + bigi P F m (n+1) = + (if P m then F m else idm) + + bigi (fun i => P (i+1)) (fun i => F (i+1)) m n. +proof. +by move=> lemn; rewrite big_mkcond big_int_recl //= -big_mkcond. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recr_cond (n m : int) P (F : int -> t): m <= n => + bigi P F m (n+1) = + bigi P F m n + (if P n then F n else idm). +proof. by move=> lemn; rewrite !(@big_mkcond P) big_int_recr. qed. + +(* -------------------------------------------------------------------- *) +lemma bigi_split_odd_even (n : int) (F : int -> t) : 0 <= n => + bigi predT (fun i => F (2 * i) + F (2 * i + 1)) 0 n + = bigi predT F 0 (2 * n). +proof. +move=> ge0_n; rewrite big_split; pose rg := range 0 n. +rewrite -(@big_mapT (fun i => 2 * i)). +rewrite -(@big_mapT (fun i => 2 * i + 1)). +rewrite -big_cat &(eq_big_perm) &(uniq_perm_eq) 2:&(range_uniq). +- rewrite cat_uniq !map_inj_in_uniq /= ~-1:/# range_uniq /=. + apply/hasPn => _ /mapP[y] /= [_ ->]. + by apply/negP; case/mapP=> ? [_] /#. +move=> x; split. +- rewrite mem_cat; case=> /mapP[y] /=; + case=> /mem_range y_rg -> {x}; apply/mem_range; + by smt(). +move/mem_range => x_rg; rewrite mem_cat. +have: forall (i : int), exists j, i = 2 * j \/ i = 2 * j + 1 by smt(). +- case/(_ x) => y [] ->>; [left | right]; apply/mapP=> /=; + by exists y; (split; first apply/mem_range); smt(). +qed. + +end section. + +(* ==================================================================== *) +(* Display wrappers: [bigA] for additive contexts, [bigM] for + multiplicative ones. Both unfold to [big] so all the lemmas above + apply transparently. The flavor tag on the carrier ([addmonoid] + vs [mulmonoid]) drives which wrapper the printer folds back to. *) +(* ==================================================================== *) +abbrev bigA ['a, 't <: addmonoid] P (F : 'a -> 't) r = big P F r. +abbrev bigM ['a, 't <: mulmonoid] P (F : 'a -> 't) r = big P F r. + +abbrev bigiA ['t <: addmonoid] (P : int -> bool) (F : int -> 't) i j = bigA P F (range i j). +abbrev bigiM ['t <: mulmonoid] (P : int -> bool) (F : int -> 't) i j = bigM P F (range i j). diff --git a/examples/tcalgebra/TcInt.ec b/examples/tcalgebra/TcInt.ec new file mode 100644 index 0000000000..f0f1466c90 --- /dev/null +++ b/examples/tcalgebra/TcInt.ec @@ -0,0 +1,78 @@ +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import Core. +require import TcMonoid TcRing. +require import Int. +require CoreInt. + +(* ==================================================================== *) +(* Canonical [int] instance for the [TcMonoid] / [TcRing] hierarchy. + Mirrors [theories/algebra/Ring.ec:IntID]. *) +(* ==================================================================== *) + +(* Named wrappers for [int]'s [unit] / [invr]: the TC instance form + requires an op-name on the rhs of [op X = …], not an inline lambda. *) +op int_unit (z : int) : bool = z = 1 \/ z = -1. +op int_invr (z : int) : int = z. + +(* -------------------------------------------------------------------- *) +(* Declaring [idomain] synthesises [comring] (and the rest of the + chain) along the way, so we don't need a separate [instance comring + with int] — declaring both would create duplicate comring witnesses + for [int] and break op-name resolution downstream. *) +instance idomain with int + op idm = CoreInt.zero + op (+) = CoreInt.add + op [-] = CoreInt.opp + op oner = CoreInt.one + op ( * ) = CoreInt.mul + op invr = int_invr + op unit = int_unit + + proof addmA by smt() + proof addmC by smt() + proof add0m by smt() + proof addrN by smt() + proof oner_neq0 by smt() + proof mulrA by smt() + proof mulrC by smt() + proof mul1r by smt() + proof mulrDl by smt() + proof mulVr by smt(@CoreInt) + proof unitP by smt() + proof unitout by smt() + proof mulf_eq0 by smt(). + +op _spacer1 : int = 0. + +(* ==================================================================== *) +(* int-specific corollaries that sit on top of the [comring] / + [idomain] instances. Mirrors the lemmas under [Ring.ec:IntID]. *) +(* ==================================================================== *) + +(* int's abstract [intmul] coincides with concrete int multiplication. *) +lemma intmulz (z c : int) : intmul z c = Int.( * ) z c. +proof. +have h: forall cp, 0 <= cp => intmul z cp = Int.( * ) z cp. + elim=> /= [|cp ge0_cp ih]; first by rewrite mulr0z. + by rewrite mulrS // ih /#. +smt(opprK mulrNz opprK). +qed. + +(* Parity of [exp x n] for [x : int] tracks parity of [x] when [n > 0]. *) +lemma poddX (n x : int) : + 0 < n => odd (exp x n) = odd x. +proof. +rewrite ltz_def => - [] + ge0_n; elim: n ge0_n => // + + _ _. +elim=> [|n ge0_n ih]; first by rewrite expr1. +by rewrite exprS ?addz_ge0 // oddM ih andbb. +qed. + +lemma oddX (n x : int) : + 0 <= n => odd (exp x n) = (odd x \/ n = 0). +proof. +rewrite lez_eqVlt; case: (n = 0) => [->// _|+ h]. ++ by rewrite expr0 odd1. ++ by case: h => [<-//|] /poddX ->. +qed. diff --git a/examples/tcalgebra/TcMonoid.ec b/examples/tcalgebra/TcMonoid.ec new file mode 100644 index 0000000000..0b9f83c3dd --- /dev/null +++ b/examples/tcalgebra/TcMonoid.ec @@ -0,0 +1,57 @@ +require import Int. + +(* ==================================================================== *) +(* Abstract monoid: where all the lemmas live, written once. *) +(* ==================================================================== *) +type class monoid = { + op idm : monoid + op (+) : monoid -> monoid -> monoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: monoid. + +lemma addm0: right_id idm<:t> (+). +proof. by move=> x; rewrite addmC add0m. qed. + +lemma addmCA: left_commutative (+)<:t>. +proof. by move=> x y z; rewrite !addmA (addmC x). qed. + +lemma addmAC: right_commutative (+)<:t>. +proof. by move=> x y z; rewrite -!addmA (addmC y). qed. + +lemma addmACA: interchange (+)<:t> (+). +proof. by move=> x y z t; rewrite -!addmA (addmCA y). qed. + +lemma iteropE n (x : t): iterop n (+) x idm = iter n ((+) x) idm. +proof. +elim/natcase n => [n le0_n|n ge0_n]. ++ by rewrite ?(iter0, iterop0). ++ by rewrite iterSr // addm0 iteropS. +qed. +end section. + +(* ==================================================================== *) +(* Flavor tags: empty subclasses of monoid. They carry no extra + structure; their only purpose is to drive display (\sum vs \prod + for bigops, [zero]/[+] vs [one]/[*] for the operators). *) +(* ==================================================================== *) +type class addmonoid <: monoid = {}. + +type class mulmonoid <: monoid = {}. + +(* -------------------------------------------------------------------- *) +(* Source-level renamings on top of [monoid]'s operators. Each abbrev is + a transparent alias; it parses to the underlying monoid op and prints + back as the alias when the type carries the matching flavor tag. *) +(* -------------------------------------------------------------------- *) +abbrev zero ['a <: addmonoid] : 'a = idm<:'a>. + +abbrev one ['a <: mulmonoid] : 'a = idm<:'a>. + +abbrev ( * ) ['a <: mulmonoid] (x y : 'a) = (+)<:'a> x y. diff --git a/examples/tcalgebra/TcNumber.ec b/examples/tcalgebra/TcNumber.ec new file mode 100644 index 0000000000..cb91620690 --- /dev/null +++ b/examples/tcalgebra/TcNumber.ec @@ -0,0 +1,1523 @@ +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import Core Int AlgTactic StdRing. +require import TcMonoid TcRing. +require import TcInt. + +(* -------------------------------------------------------------------- *) +pred homo2 ['a 'b] (op_ : 'a -> 'b) (aR : 'a rel) (rR : 'b rel) = + forall x y, aR x y => rR (op_ x) (op_ y). + +pred mono2 ['a 'b] (op_ : 'a -> 'b) (aR : 'a rel) (rR : 'b rel) = + forall x y, rR (op_ x) (op_ y) <=> aR x y. + +lemma mono2W f (aR : 'a rel) (rR : 'b rel) : + mono2 f aR rR => homo2 f aR rR. +proof. by move=> + x y - ->. qed. + +lemma monoLR ['a 'b] f g (aR : 'a rel) (rR : 'b rel) : + cancel g f => mono2 f aR rR => forall x y, + rR (f x) y <=> aR x (g y). +proof. by move=> can_gf mf x y; rewrite -{1}[y]can_gf mf. qed. + +lemma monoRL ['a 'b] f g (aR : 'a rel) (rR : 'b rel) : + cancel g f => mono2 f aR rR => forall x y, + rR x (f y) <=> aR (g x) y. +proof. by move=> can_gf mf x y; rewrite -{1}can_gf mf. qed. + +(* ==================================================================== *) +(* Real-closed domain: ordered integral domain with norm. Mirrors *) +(* [theories/algebra/Number.ec:RealDomain] but as a TC class on top *) +(* of [idomain]. *) +(* ==================================================================== *) +type class tcrealdomain <: idomain = { + op "`|_|" : tcrealdomain -> tcrealdomain + op ( <= ) : tcrealdomain -> tcrealdomain -> bool + op ( < ) : tcrealdomain -> tcrealdomain -> bool + op minr : tcrealdomain -> tcrealdomain -> tcrealdomain + op maxr : tcrealdomain -> tcrealdomain -> tcrealdomain + + axiom ler_norm_add : + forall (x y : tcrealdomain), `|x + y| <= `|x| + `|y| + axiom addr_gt0 : + forall (x y : tcrealdomain), zero<:tcrealdomain> < x => zero < y => zero < x + y + axiom norm_eq0 : + forall (x : tcrealdomain), `|x| = zero<:tcrealdomain> => x = zero + axiom ger_leVge : + forall (x y : tcrealdomain), + zero<:tcrealdomain> <= x => zero <= y => (x <= y) \/ (y <= x) + axiom normrM : + forall (x y : tcrealdomain), `|x * y| = `|x| * `|y| + axiom ler_def : + forall (x y : tcrealdomain), x <= y <=> `|y - x| = y - x + axiom ltr_def : + forall (x y : tcrealdomain), x < y <=> (y <> x) /\ x <= y + axiom real_axiom : + forall (x : tcrealdomain), zero<:tcrealdomain> <= x \/ x <= zero + axiom minrE : + forall (x y : tcrealdomain), minr x y = if x <= y then x else y + axiom maxrE : + forall (x y : tcrealdomain), maxr x y = if y <= x then x else y +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: tcrealdomain. + +(* -------------------------------------------------------------------- *) +(* Sign / positivity / order reflexivity *) +(* -------------------------------------------------------------------- *) + +lemma ger0_def (x : t): (zero <= x) <=> (`|x| = x). +proof. by rewrite ler_def subr0. qed. + +lemma subr_ge0 (x y : t): (zero <= x - y) <=> (y <= x). +proof. by rewrite ger0_def -ler_def. qed. + +lemma oppr_ge0 (x : t): (zero <= -x) <=> (x <= zero). +proof. by rewrite -sub0r subr_ge0. qed. + +lemma ler01: zero<:t> <= oner. +proof. +have n1_nz: `|oner<:t>| <> zero. ++ apply/(contraNneq _ _ (oner_neq0<:t>)) => /norm_eq0->; trivial. +by rewrite ger0_def -(inj_eq (mulfI _ n1_nz)) -normrM !mulr1. +qed. + +lemma ltr01: zero<:t> < oner. +proof. by rewrite ltr_def oner_neq0 ler01. qed. + +hint exact : ler01 ltr01. + +lemma ltrW (x y : t): x < y => x <= y. +proof. by rewrite ltr_def. qed. + +lemma lerr (x : t): x <= x. +proof. +have n2: `|ofint<:t> 2| = ofint 2. + rewrite -ger0_def (@ofintS 1) // ofint1 ltrW //. + by rewrite addr_gt0 ?ltr01. +rewrite ler_def subrr -(inj_eq (addrI `|zero<:t>|)) /= addr0. +by rewrite -mulr2z -mulr_intr -n2 -normrM mul0r. +qed. + +hint exact : lerr. + +lemma lerr_eq (x y : t): x = y => x <= y. +proof. by move=> ->; rewrite lerr. qed. + +lemma ltrr (x : t): !(x < x). +proof. by rewrite ltr_def. qed. + +lemma ltr_neqAle (x y : t): + (x < y) <=> (x <> y) /\ (x <= y). +proof. by rewrite ltr_def eq_sym. qed. + +lemma ler_eqVlt (x y : t): + (x <= y) <=> (x = y) \/ (x < y). +proof. by rewrite ltr_neqAle; case: (x = y)=> // ->; rewrite lerr. qed. + +lemma lt0r (x : t): + (zero < x) <=> (x <> zero) /\ (zero <= x). +proof. by rewrite ltr_def. qed. + +lemma le0r (x : t): + (zero <= x) <=> (x = zero) \/ (zero < x). +proof. by rewrite ler_eqVlt eq_sym. qed. + +lemma addr_ge0 (x y : t): + zero <= x => zero <= y => zero <= x + y. +proof. +rewrite le0r; case=> [->|gt0x]; rewrite ?add0r // le0r. +by case=> [->|gt0y]; rewrite ltrW ?addr0 ?addr_gt0. +qed. + +lemma lt0r_neq0 (x : t): + zero < x => (x <> zero). +proof. by rewrite lt0r; case (_ = _). qed. + +lemma ltr0_neq0 (x : t): + zero < x => (x <> zero). +proof. by rewrite lt0r; case: (_ = _). qed. + +lemma gtr_eqF (x y : t): + y < x => (x <> y). +proof. by rewrite ltr_def => -[]. qed. + +lemma ltr_eqF (x y : t): + x < y => (x <> y). +proof. by rewrite eq_sym=> /gtr_eqF ->. qed. + +lemma ler0n n : 0 <= n => zero<:t> <= ofint n. +proof. +elim: n => [|n ih h]; first by rewrite ofint0 lerr. +by rewrite ofintS // addr_ge0 // ?ler01. +qed. + +lemma ltr0Sn n : 0 <= n => zero<:t> < ofint (n + 1). +proof. +elim: n=> /= [|n ge0n ih]; first by rewrite ofint1 ltr01. +by rewrite (@ofintS (n+1)) // ?(addz_ge0, addr_gt0) // ltr01. +qed. + +lemma ltr0n n : 0 <= n => (zero<:t> < ofint n) = (0 < n). +proof. +elim: n => [|n ge0n _]; first by rewrite ofint0 ltrr. +by rewrite ltr0Sn // ltz_def addz_ge0 ?addz1_neq0. +qed. + +lemma pnatr_eq0 n : 0 <= n => (ofint<:t> n = zero) <=> (n = 0). +proof. +elim: n => [|n ge0n _]; rewrite ?ofint0 // gtr_eqF. + by apply: ltr0Sn. by rewrite addz1_neq0. +qed. + +lemma pmulr_rgt0 (x y : t): + zero < x => (zero < x * y) <=> (zero < y). +proof. +rewrite !ltr_def !ger0_def normrM mulf_eq0 negb_or. +by case=> ^nz_x -> -> /=; have /inj_eq -> := mulfI _ nz_x. +qed. + +lemma pmulr_rge0 (x y : t): + zero < x => (zero <= x * y) <=> (zero <= y). +proof. +rewrite !le0r mulf_eq0; case: (y = _) => //= ^lt0x. +by move/lt0r_neq0=> -> /=; apply/pmulr_rgt0. +qed. + +lemma normr_idP (x : t): (`|x| = x) <=> (zero <= x). +proof. by rewrite ger0_def. qed. + +lemma ger0_norm (x : t): zero <= x => `|x| = x. +proof. by apply/normr_idP. qed. + +lemma normr0: `|zero<:t>| = zero. +proof. by apply/ger0_norm/lerr. qed. + +lemma normr1: `|oner<:t>| = oner. +proof. by apply/ger0_norm/ler01. qed. + +lemma normr_nat n : 0 <= n => `|ofint<:t> n| = ofint n. +proof. by move=> n_0ge; rewrite ger0_norm // ler0n. qed. + +lemma normr0_eq0 (x : t): `|x| = zero => x = zero. +proof. by apply/norm_eq0. qed. + +lemma normr0P (x : t): (`|x| = zero) <=> (x = zero). +proof. by split=> [/norm_eq0|->] //; rewrite normr0. qed. + +lemma normrX_nat n (x : t) : 0 <= n => `|exp x n| = exp `|x| n. +proof. +elim: n=> [|n ge0_n ih]; first by rewrite !expr0 normr1. +by rewrite !exprS //= normrM ih. +qed. + +lemma normrN1: `|-oner<:t>| = oner. +proof. +have: exp `|-oner<:t>| 2 = oner. + by rewrite -normrX_nat -1?signr_odd // odd2 expr0 normr1. +rewrite sqrf_eq1=> -[->//|]; rewrite -ger0_def le0r oppr_eq0. +by rewrite oner_neq0 /= => /(addr_gt0 _ _ ltr01); rewrite addrN ltrr. +qed. + +lemma normrZ (x y : t) : zero <= x => `| x * y | = x * `| y |. +proof. by move=> ge0; rewrite normrM ger0_norm. qed. + +lemma normrN (x : t): `|- x| = `|x|. +proof. by rewrite -mulN1r normrM normrN1 mul1r. qed. + +lemma distrC (x y : t): `|x - y| = `|y - x|. +proof. by rewrite -opprB normrN. qed. + +lemma ler0_def (x : t): (x <= zero) <=> (`|x| = - x). +proof. by rewrite ler_def sub0r normrN. qed. + +lemma normr_unit : forall (x : t), unit x => unit `|x|. +proof. +move=> x; rewrite !unitrP => -[y yx]. +by exists `|y|; rewrite -normrM yx normr1. +qed. + +lemma ler0_norm (x : t): x <= zero => `|x| = - x. +proof. +move=> x_le0; rewrite eq_sym -(@ger0_norm (-x)). + by rewrite oppr_ge0. by rewrite normrN. +qed. + +lemma unit_normr (x : t): unit (`|x|) => unit x. +proof. +case: (real_axiom x) => [le0n|len0]. + by move: (normr_idP x); rewrite le0n /= => ->. +by rewrite ler0_norm // unitrN. +qed. + +lemma normrV : forall (x : t), `|invr x| = invr `|x|. +proof. +move=>x. +case: (unit x) => ux. ++ apply/(@mulrI `|x|); 1: by apply/normr_unit. + by rewrite -normrM !mulrV ?normr_unit // normr1. +rewrite !unitout //; apply: contra ux. +by apply unit_normr. +qed. + +lemma normr_id (x : t): `| `|x| | = `|x|. +proof. +have nz2: ofint<:t> 2 <> zero by rewrite pnatr_eq0. +apply: (mulfI _ nz2); rewrite -{1}normr_nat // -normrM. +rewrite mulr_intl mulr2z ger0_norm // -{2}normrN. +by rewrite -normr0 -(@subrr x) ler_norm_add. +qed. + +lemma normr_ge0 (x : t): zero <= `|x|. +proof. by rewrite ger0_def normr_id. qed. + +lemma gtr0_norm (x : t): zero < x => `|x| = x. +proof. by move/ltrW/ger0_norm. qed. + +lemma ltr0_norm (x : t): x < zero => `|x| = - x. +proof. by move/ltrW/ler0_norm. qed. + +lemma subr_gt0 (x y : t): (zero < y - x) <=> (x < y). +proof. by rewrite !ltr_def subr_eq0 subr_ge0. qed. + +lemma subr_le0 (x y : t): (y - x <= zero) <=> (y <= x). +proof. by rewrite -subr_ge0 opprB add0r subr_ge0. qed. + +lemma subr_lt0 (x y : t): (y - x < zero) <=> (y < x). +proof. by rewrite -subr_gt0 opprB add0r subr_gt0. qed. + +lemma ler_asym (x y : t): x <= y <= x => x = y. +proof. +rewrite !ler_def distrC -opprB -addr_eq0 => -[->]. +by rewrite -mulr2z -mulr_intl mulf_eq0 subr_eq0 pnatr_eq0. +qed. + +lemma eqr_le (x y : t): (x = y) <=> (x <= y <= x). +proof. by split=> [->|/ler_asym]; rewrite ?lerr. qed. + +lemma ltr_trans (y x z : t): x < y => y < z => x < z. +proof. +move=> le_xy le_yz; rewrite -subr_gt0 -(@subrK z y). +by rewrite -addrA addr_gt0 ?subr_gt0. +qed. + +lemma ler_lt_trans (y x z : t): x <= y => y < z => x < z. +proof. by rewrite !ler_eqVlt => -[-> //|/ltr_trans h]; apply/h. qed. + +lemma ltr_le_trans (y x z : t): x < y => y <= z => x < z. +proof. by rewrite !ler_eqVlt => lxy [<- //|lyz]; apply (@ltr_trans y). qed. + +lemma ler_trans (y x z : t): x <= y => y <= z => x <= z. +proof. +rewrite !ler_eqVlt => -[-> //|lxy] [<-|]. + by rewrite lxy. by move/(ltr_trans _ _ _ lxy) => ->. +qed. + +lemma ltr_asym (x y : t): ! (x < y < x). +proof. by apply/negP=> -[/ltr_trans hyx /hyx]; rewrite ltrr. qed. + +lemma ler_anti (x y : t): x <= y <= x => x = y. +proof. by rewrite -eqr_le. qed. + +lemma ltr_le_asym (x y : t): ! (x < y <= x). +proof. +rewrite andaE ltr_neqAle -andbA -!andaE. +by rewrite -eqr_le eq_sym; case: (_ = _). +qed. + +lemma ler_lt_asym (x y : t): + ! (x <= y < x). +proof. by rewrite andaE andbC -andaE ltr_le_asym. qed. + +lemma ltr_geF (x y : t): x < y => ! (y <= x). +proof. by move=> xy; apply/negP => /(ltr_le_trans _ _ _ xy); rewrite ltrr. qed. + +lemma ler_gtF (x y : t): x <= y => ! (y < x). +proof. by move=> le_xy; apply/negP=> /ltr_geF. qed. + +lemma ltr_gtF (x y : t): x < y => ! (y < x). +proof. by move/ltrW/ler_gtF. qed. + +lemma normr_le0 (x : t): (`|x| <= zero) <=> (x = zero). +proof. by rewrite -normr0P eqr_le normr_ge0. qed. + +lemma normr_lt0 (x : t): ! (`|x| < zero). +proof. by rewrite ltr_neqAle normr_le0 normr0P; case: (_ = _). qed. + +lemma normr_gt0 (x : t): (zero < `|x|) <=> (x <> zero). +proof. by rewrite ltr_def normr0P normr_ge0; case: (_ = _). qed. + +lemma normrX n (x : t) : `|exp x n| = exp `|x| n. +proof. +case (0 <= n); [by apply normrX_nat|]. +rewrite -ltzNge -{1}(invrK x) exprV => ltn0. +rewrite normrX_nat; [by rewrite oppz_ge0 ltzW|]. +case: (unit x) => [unitx|Nunitx]. + by rewrite normrV // exprV. +move: (unit_normr x) => /contra; rewrite Nunitx /=. +move => unitNx; rewrite invr_out //. +by rewrite -{1}(@invr_out `|_|) // exprV. +qed. + +(*-------------------------------------------------------------------- *) +hint rewrite normrE : normr_id normr0 normr1 normrN1. +hint rewrite normrE : normr_ge0 normr_lt0 normr_le0 normr_gt0. +hint rewrite normrE : normrN. + +(* -------------------------------------------------------------------- *) +lemma mono_inj (f : t -> t) : mono2 f (<=) (<=) => injective f. +proof. by move=> mf x y; rewrite eqr_le !mf -eqr_le. qed. + +lemma nmono_inj (f : t -> t) : mono2 f (fun y x => x <= y) (<=) => injective f. +proof. by move=> mf x y; rewrite eqr_le !mf -eqr_le. qed. + +lemma lerW_mono (f : t -> t) : mono2 f (<=) (<=) => mono2 f (<) (<). +proof. +move=> mf x y; rewrite !ltr_neqAle mf. +by rewrite inj_eq //; apply/mono_inj. +qed. + +lemma lerW_nmono (f : t -> t) : + mono2 f (fun y x => x <= y) (<=) + => mono2 f (fun y x => x < y) (<). +proof. +move=> mf x y; rewrite !ltr_neqAle mf eq_sym. +by rewrite inj_eq //; apply/nmono_inj. +qed. + +(* -------------------------------------------------------------------- *) +lemma ler_opp2 (x y : t): (-x <= -y) <=> (y <= x). +proof. by rewrite -subr_ge0 opprK addrC subr_ge0. qed. + +lemma ltr_opp2 (x y : t): (-x < -y) <=> (y < x). +proof. by rewrite lerW_nmono //; apply/ler_opp2. qed. + +lemma ler_oppr (x y : t): (x <= - y) <=> (y <= - x). +proof. by rewrite (monoRL opprK ler_opp2). qed. + +hint rewrite lter_opp2 : ler_opp2 ltr_opp2. + +lemma ltr_oppr (x y : t): (x < - y) <=> (y < - x). +proof. by rewrite (monoRL opprK (:@lerW_nmono _ ler_opp2)). qed. + +lemma ler_oppl (x y : t): + (- x <= y) <=> (- y <= x). +proof. by rewrite (monoLR opprK ler_opp2). qed. + +lemma ltr_oppl (x y : t): + (- x < y) <=> (- y < x). +proof. by rewrite (monoLR opprK (:@lerW_nmono _ ler_opp2)). qed. + +lemma oppr_gt0 (x : t): (zero < - x) <=> (x < zero). +proof. by rewrite ltr_oppr oppr0. qed. + +lemma oppr_le0 (x : t): (- x <= zero) <=> (zero <= x). +proof. by rewrite ler_oppl oppr0. qed. + +lemma oppr_lt0 (x : t): (- x < zero) <=> (zero < x). +proof. by rewrite ltr_oppl oppr0. qed. + +hint rewrite oppr_gte0 : oppr_ge0 oppr_gt0. +hint rewrite oppr_lte0 : oppr_le0 oppr_lt0. +hint rewrite oppr_cp0 : oppr_ge0 oppr_gt0 oppr_le0 oppr_lt0. +hint rewrite lter_oppE : oppr_le0 oppr_lt0 oppr_ge0 oppr_gt0. +hint rewrite lter_oppE : ler_opp2 ltr_opp2. + +(* -------------------------------------------------------------------- *) +lemma ler_leVge (x y : t): + x <= zero => y <= zero => (x <= y) \/ (y <= x). +proof. by rewrite -!oppr_ge0 => /(ger_leVge _) h /h; rewrite !ler_opp2 orbC. qed. + +lemma ler_add2l (x y z : t) : (x + y <= x + z) <=> (y <= z). +proof. by rewrite -subr_ge0 opprD addrAC addNKr addrC subr_ge0. qed. + +lemma ler_add2r (x y z : t) : (y + x <= z + x) <=> (y <= z). +proof. by rewrite !(@addrC _ x) ler_add2l. qed. + +lemma ltr_add2r (z x y : t): (x + z < y + z) <=> (x < y). +proof. by apply/(@lerW_mono (fun u => u + z) (:@ler_add2r z)). qed. + +lemma ltr_add2l (z x y : t): (z + x < z + y) <=> (x < y). +proof. by apply/(@lerW_mono (fun u => z + u) (:@ler_add2l z)). qed. + +hint rewrite ler_add2 : ler_add2l ler_add2r. +hint rewrite ltr_add2 : ltr_add2l ltr_add2r. +hint rewrite lter_add2 : ler_add2l ler_add2r ltr_add2l ltr_add2r. + +lemma ler_add (x y z u : t): + x <= y => z <= u => x + z <= y + u. +proof. by move=> xy zt; rewrite (@ler_trans (y + z)) ?lter_add2. qed. + +lemma ler_lt_add (x y z u : t): + x <= y => z < u => x + z < y + u. +proof. by move=> xy zt; rewrite (@ler_lt_trans (y + z)) ?lter_add2. qed. + +lemma ltr_le_add (x y z u : t): + x < y => z <= u => x + z < y + u. +proof. by move=> xy zt; rewrite (@ltr_le_trans (y + z)) ?lter_add2. qed. + +lemma ltr_add (x y z u : t): x < y => z < u => x + z < y + u. +proof. by move=> xy zt; rewrite ltr_le_add // ltrW. qed. + +lemma ler_sub (x y z u : t): + x <= y => u <= z => x - z <= y - u. +proof. by move=> xy tz; rewrite ler_add ?lter_opp2. qed. + +lemma ler_lt_sub (x y z u : t): + x <= y => u < z => x - z < y - u. +proof. by move=> xy zt; rewrite ler_lt_add ?lter_opp2. qed. + +lemma ltr_le_sub (x y z u : t): + x < y => u <= z => x - z < y - u. +proof. by move=> xy zt; rewrite ltr_le_add ?lter_opp2. qed. + +lemma ltr_sub (x y z u : t): + x < y => u < z => x - z < y - u. +proof. by move=> xy tz; rewrite ltr_add ?lter_opp2. qed. + +lemma ler_subl_addr (x y z : t): + (x - y <= z) <=> (x <= z + y). +proof. by rewrite (monoLR (:@addrK y) (:@ler_add2r (-y))). qed. + +lemma ltr_subl_addr (x y z : t): + (x - y < z) <=> (x < z + y). +proof. by rewrite (monoLR (:@addrK y) (:@ltr_add2r (-y))). qed. + +lemma ler_subr_addr (x y z : t): + (x <= y - z) <=> (x + z <= y). +proof. by rewrite (monoLR (:@addrNK z) (:@ler_add2r z)). qed. + +lemma ltr_subr_addr (x y z : t): + (x < y - z) <=> (x + z < y). +proof. by rewrite (monoLR (:@addrNK z) (:@ltr_add2r z)). qed. + +hint rewrite ler_sub_addr : ler_subl_addr ler_subr_addr. +hint rewrite ltr_sub_addr : ltr_subl_addr ltr_subr_addr. +hint rewrite lter_sub_addr : ler_subl_addr ler_subr_addr. +hint rewrite lter_sub_addr : ltr_subl_addr ltr_subr_addr. + +lemma ler_subl_addl (x y z : t): + (x - y <= z) <=> (x <= y + z). +proof. by rewrite lter_sub_addr addrC. qed. + +lemma ltr_subl_addl (x y z : t): + (x - y < z) <=> (x < y + z). +proof. by rewrite lter_sub_addr addrC. qed. + +lemma ler_subr_addl (x y z : t): + (x <= y - z) <=> (z + x <= y). +proof. by rewrite lter_sub_addr addrC. qed. + +lemma ltr_subr_addl (x y z : t): + (x < y - z) <=> (z + x < y). +proof. by rewrite lter_sub_addr addrC. qed. + +hint rewrite ler_sub_addl : ler_subl_addl ler_subr_addl. +hint rewrite ltr_sub_addl : ltr_subl_addl ltr_subr_addl. +hint rewrite lter_sub_addl : ler_subl_addl ler_subr_addl. +hint rewrite lter_sub_addl : ltr_subl_addl ltr_subr_addl. + +lemma ler_addl (x y : t): (x <= x + y) <=> (zero <= y). +proof. by rewrite -{1}(@addr0 x) lter_add2. qed. + +lemma ltr_addl (x y : t): (x < x + y) <=> (zero < y). +proof. by rewrite -{1}(@addr0 x) lter_add2. qed. + +lemma ler_addr (x y : t): (x <= y + x) <=> (zero <= y). +proof. by rewrite -{1}(@add0r x) lter_add2. qed. + +lemma ltr_addr (x y : t): (x < y + x) <=> (zero < y). +proof. by rewrite -{1}(@add0r x) lter_add2. qed. + +lemma ger_addl (x y : t): (x + y <= x) <=> (y <= zero). +proof. by rewrite -{2}(@addr0 x) lter_add2. qed. + +lemma gtr_addl (x y : t): (x + y < x) <=> (y < zero). +proof. by rewrite -{2}(@addr0 x) lter_add2. qed. + +lemma ger_addr (x y : t): (y + x <= x) <=> (y <= zero). +proof. by rewrite -{2}(@add0r x) lter_add2. qed. + +lemma gtr_addr (x y : t): (y + x < x) <=> (y < zero). +proof. by rewrite -{2}(@add0r x) lter_add2. qed. + +hint rewrite cpr_add : ler_addl ler_addr ger_addl ger_addl. +hint rewrite cpr_add : ltr_addl ltr_addr gtr_addl gtr_addl. + +lemma ler_paddl (y x z : t): + zero <= x => y <= z => y <= x + z. +proof. by move=> ??; rewrite -(@add0r y) ler_add. qed. + +lemma ltr_paddl (y x z : t): + zero <= x => y < z => y < x + z. +proof. by move=> ??; rewrite -(@add0r y) ler_lt_add. qed. + +lemma ltr_spaddl (y x z : t): + zero < x => y <= z => y < x + z. +proof. by move=> ??; rewrite -(@add0r y) ltr_le_add. qed. + +lemma ltr_spsaddl (y x z : t): + zero < x => y < z => y < x + z. +proof. by move=> ??; rewrite -(@add0r y) ltr_add. qed. + +lemma ler_naddl (y x z : t): + x <= zero => y <= z => x + y <= z. +proof. by move=> ??; rewrite -(@add0r z) ler_add. qed. + +lemma ltr_naddl (y x z : t): + x <= zero => y < z => x + y < z. +proof. by move=> ??; rewrite -(@add0r z) ler_lt_add. qed. + +lemma ltr_snaddl (y x z : t): + x < zero => y <= z => x + y < z. +proof. by move=> ??; rewrite -(@add0r z) ltr_le_add. qed. + +lemma ltr_snsaddl (y x z : t): + x < zero => y < z => x + y < z. +proof. by move=> ??; rewrite -(@add0r z) ltr_add. qed. + +lemma ler_paddr (y x z : t): + zero <= x => y <= z => y <= z + x. +proof. by move=> ??; rewrite (@addrC _ x) ler_paddl. qed. + +lemma ltr_paddr (y x z : t): + zero <= x => y < z => y < z + x. +proof. by move=> ??; rewrite (@addrC _ x) ltr_paddl. qed. + +lemma ltr_spaddr (y x z : t): + zero < x => y <= z => y < z + x. +proof. by move=> ??; rewrite (@addrC _ x) ltr_spaddl. qed. + +lemma ltr_spsaddr (y x z : t): + zero < x => y < z => y < z + x. +proof. by move=> ??; rewrite (@addrC _ x) ltr_spsaddl. qed. + +lemma ler_naddr (y x z : t): + x <= zero => y <= z => y + x <= z. +proof. by move=> ??; rewrite (@addrC _ x) ler_naddl. qed. + +lemma ltr_naddr (y x z : t): + x <= zero => y < z => y + x < z. +proof. by move=> ??; rewrite (@addrC _ x) ltr_naddl. qed. + +lemma ltr_snaddr (y x z : t): + x < zero => y <= z => y + x < z. +proof. by move=> ??; rewrite (@addrC _ x) ltr_snaddl. qed. + +lemma ltr_snsaddr (y x z : t): + x < zero => y < z => y + x < z. +proof. by move=> ??; rewrite (@addrC _ x) ltr_snsaddl. qed. + +(* -------------------------------------------------------------------- *) +lemma paddr_eq0 (x y : t): + zero <= x => zero <= y => (x + y = zero) <=> (x = zero) /\ (y = zero). +proof. +rewrite le0r=> -[->|hx]; first by rewrite add0r. +by rewrite (gtr_eqF hx) /= => hy; rewrite gtr_eqF // ltr_spaddl. +qed. + +lemma naddr_eq0 (x y : t): + x <= zero => y <= zero => (x + y = zero) <=> (x = zero) /\ (y = zero). +proof. +by move=> lex0 ley0; rewrite -oppr_eq0 opprD paddr_eq0 ?oppr_cp0 // !oppr_eq0. +qed. + +lemma addr_ss_eq0 (x y : t): + (zero <= x) /\ (zero <= y) \/ + (x <= zero) /\ (y <= zero) => + (x + y = zero) <=> (x = zero) /\ (y = zero). +proof. by case=> -[]; [apply: paddr_eq0 | apply: naddr_eq0]. qed. + +(* -------------------------------------------------------------------- *) +lemma ler_pmul2l (x : t) : + zero < x => forall y z, (x * y <= x * z) <=> (y <= z). +proof. +move=> x_gt0 y z /=; rewrite -subr_ge0 -mulrBr. +by rewrite pmulr_rge0 // subr_ge0. +qed. + +lemma ltr_pmul2l (x : t) : + zero < x => forall y z, (x * y < x * z) <=> (y < z). +proof. by move=> x_gt0; apply/lerW_mono/ler_pmul2l. qed. + +hint rewrite lter_pmul2l : ler_pmul2l ltr_pmul2l. + +lemma ler_pmul2r (x : t) : + zero < x => forall y z, (y * x <= z * x) <=> (y <= z). +proof. by move=> x_gt0 y z /=; rewrite !(@mulrC _ x) ler_pmul2l. qed. + +lemma ltr_pmul2r (x : t) : + zero < x => forall y z, (y * x < z * x) <=> (y < z). +proof. by move=> x_gt0; apply/lerW_mono/ler_pmul2r. qed. + +hint rewrite lter_pmul2r : ler_pmul2r ltr_pmul2r. + +lemma ler_nmul2l (x : t) : + x < zero => forall y z, (x * y <= x * z) <=> (z <= y). +proof. by move=> x_lt0 y z /=; rewrite -ler_opp2 -!mulNr ler_pmul2l ?oppr_gt0. qed. + +lemma ltr_nmul2l (x : t) : + x < zero => forall y z, (x * y < x * z) <=> (z < y). +proof. by move=> x_lt0; apply/lerW_nmono/ler_nmul2l. qed. + +hint rewrite lter_nmul2l : ler_nmul2l ltr_nmul2l. + +lemma ler_nmul2r (x : t) : + x < zero => forall y z, (y * x <= z * x) <=> (z <= y). +proof. by move=> x_lt0 y z /=; rewrite !(@mulrC _ x) ler_nmul2l. qed. + +lemma ltr_nmul2r (x : t) : + x < zero => forall y z, (y * x < z * x) <=> (z < y). +proof. by move=> x_lt0; apply/lerW_nmono/ler_nmul2r. qed. + +hint rewrite lter_nmul2r : ler_nmul2r ltr_nmul2r. + +(* -------------------------------------------------------------------- *) +lemma ler_wpmul2l (x : t) : + zero <= x => forall y z, y <= z => x * y <= x * z. +proof. +rewrite le0r => -[-> y z|/ler_pmul2l/mono2W ? //]. + by rewrite !mul0r lerr. +qed. + +lemma ler_wpmul2r (x : t) : + zero <= x => forall y z, y <= z => y * x <= z * x. +proof. by move=> x_ge0 y z leyz; rewrite !(@mulrC _ x) ler_wpmul2l. qed. + +lemma ler_wnmul2l (x : t) : + x <= zero => forall y z, y <= z => x * z <= x * y. +proof. +by move=> x_le0 y z leyz; rewrite -!(@mulrNN x) ler_wpmul2l ?lter_oppE. +qed. + +lemma ler_wnmul2r (x : t) : + x <= zero => forall y z, y <= z => z * x <= y * x. +proof. +by move=> x_le0 y z leyz; rewrite -!(@mulrNN _ x) ler_wpmul2r ?lter_oppE. +qed. + +(* -------------------------------------------------------------------- *) +lemma ler_pmul (x1 y1 x2 y2 : t): + zero <= x1 => zero <= x2 => x1 <= y1 => x2 <= y2 => x1 * x2 <= y1 * y2. +proof. +move=> x1ge0 x2ge0 le_xy1 le_xy2; have y1ge0 := ler_trans _ _ _ x1ge0 le_xy1. +have le1 := ler_wpmul2r _ x2ge0 _ _ le_xy1. +have le2 := ler_wpmul2l _ y1ge0 _ _ le_xy2. +by apply/(ler_trans _ le1 le2). +qed. + +lemma ltr_pmul (x1 y1 x2 y2 : t): + zero <= x1 => zero <= x2 => x1 < y1 => x2 < y2 => x1 * x2 < y1 * y2. +proof. +move=> x1ge0 x2ge0 lt_xy1 lt_xy2; apply/(@ler_lt_trans (y1 * x2)). + by apply/ler_wpmul2r/ltrW. +by apply/ltr_pmul2l=> //; apply/(ler_lt_trans _ x1ge0). +qed. + +(* -------------------------------------------------------------------- *) +lemma ler_total (x y : t) : (x <= y) \/ (y <= x). +proof. +have := real_axiom y; have := real_axiom x. +case: (zero <= x)=> /= [x_ge0|x_nge0 x_le0]; last first. + case: (zero <= y)=> /=; first by move/(ler_trans _ _ _ x_le0)=> ->. + by move=> _ /(ler_leVge _ _ x_le0). +by case=> [/(ger_leVge _ _ x_ge0) //| /ler_trans ->]. +qed. + +lemma ltr_total (x y : t) : x <> y => (x < y) \/ (y < x). +proof. by rewrite !ltr_def (@eq_sym _ y) => -> /=; apply: ler_total. qed. + +lemma ltrNge (x y : t): (x < y) <=> !(y <= x). +proof. +rewrite ltr_def; have := ler_total x y. +by case: (x <= y)=> //=; rewrite eqr_le => ->. +qed. + +lemma lerNgt (x y : t): (x <= y) <=> !(y < x). +proof. by rewrite ltrNge. qed. + +(* -------------------------------------------------------------------- *) +lemma pmulr_gt0 (x y : t) : zero <= x => zero <= y => + zero < x * y <=> zero < x /\ zero < y. +proof. +move=> x_ge0 y_ge0; split; last by smt(pmulr_rgt0). +smt (pmulr_rgt0 ltrNge ler_anti mul0r ltrr). +qed. + +(* -------------------------------------------------------------------- *) +lemma leVge (x y : t) : (x <= y) \/ (y <= x). +proof. exact ler_total. qed. + +lemma leVgt (x y : t) : (x <= y) \/ (y < x). +proof. by case: (x <= y) => // /ltrNge. qed. + +(* -------------------------------------------------------------------- *) +lemma ltrN10: -oner<:t> < zero. +proof. by rewrite oppr_lt0 ltr01. qed. + +lemma lerN10: -oner<:t> <= zero. +proof. by rewrite oppr_le0 ler01. qed. + +lemma ltr0N1: !(zero<:t> < -oner). +proof. by rewrite ler_gtF // lerN10. qed. + +lemma ler0N1: !(zero<:t> <= -oner). +proof. by rewrite ltr_geF // ltrN10. qed. + +lemma pmulr_rlt0 (x y : t): + zero < x => (x * y < zero) <=> (y < zero). +proof. +by move=> x_gt0; rewrite -oppr_gt0 -mulrN pmulr_rgt0 // oppr_gt0. +qed. + +lemma pmulr_rle0 (x y : t): + zero < x => (x * y <= zero) <=> (y <= zero). +proof. +by move=> x_gt0; rewrite -oppr_ge0 -mulrN pmulr_rge0 // oppr_ge0. +qed. + +lemma pmulr_lgt0 (x y : t): + zero < x => (zero < y * x) <=> (zero < y). +proof. by move=> x_gt0; rewrite mulrC pmulr_rgt0. qed. + +lemma pmulr_lge0 (x y : t): + zero < x => (zero <= y * x) <=> (zero <= y). +proof. by move=> x_gt0; rewrite mulrC pmulr_rge0. qed. + +lemma pmulr_llt0 (x y : t): + zero < x => (y * x < zero) <=> (y < zero). +proof. by move=> x_gt0; rewrite mulrC pmulr_rlt0. qed. + +lemma pmulr_lle0 (x y : t): + zero < x => (y * x <= zero) <=> (y <= zero). +proof. by move=> x_gt0; rewrite mulrC pmulr_rle0. qed. + +lemma nmulr_rgt0 (x y : t): + x < zero => (zero < x * y) <=> (y < zero). +proof. by move=> x_lt0; rewrite -mulrNN pmulr_rgt0 lter_oppE. qed. + +lemma nmulr_rge0 (x y : t): + x < zero => (zero <= x * y) <=> (y <= zero). +proof. by move=> x_lt0; rewrite -mulrNN pmulr_rge0 lter_oppE. qed. + +lemma nmulr_rlt0 (x y : t): + x < zero => (x * y < zero) <=> (zero < y). +proof. by move=> x_lt0; rewrite -mulrNN pmulr_rlt0 lter_oppE. qed. + +lemma nmulr_rle0 (x y : t): + x < zero => (x * y <= zero) <=> (zero <= y). +proof. by move=> x_lt0; rewrite -mulrNN pmulr_rle0 lter_oppE. qed. + +lemma nmulr_lgt0 (x y : t): + x < zero => (zero < y * x) <=> (y < zero). +proof. by move=> x_lt0; rewrite mulrC nmulr_rgt0. qed. + +lemma nmulr_lge0 (x y : t): + x < zero => (zero <= y * x) <=> (y <= zero). +proof. by move=> x_lt0; rewrite mulrC nmulr_rge0. qed. + +lemma nmulr_llt0 (x y : t): + x < zero => (y * x < zero) <=> (zero < y). +proof. by move=> x_lt0; rewrite mulrC nmulr_rlt0. qed. + +lemma nmulr_lle0 (x y : t): + x < zero => (y * x <= zero) <=> (zero <= y). +proof. by move=> x_lt0; rewrite mulrC nmulr_rle0. qed. + +lemma mulr_ge0 (x y : t): + zero <= x => zero <= y => zero <= x * y. +proof. by move=> x_ge0 y_ge0; rewrite -(mulr0 x) ler_wpmul2l. qed. + +lemma mulr_le0 (x y : t): + x <= zero => y <= zero => zero <= x * y. +proof. by move=> x_le0 y_le0; rewrite -(mulr0 x) ler_wnmul2l. qed. + +lemma mulr_ge0_le0 (x y : t): + zero <= x => y <= zero => x * y <= zero. +proof. by move=> x_le0 y_le0; rewrite -(mulr0 x) ler_wpmul2l. qed. + +lemma mulr_le0_ge0 (x y : t): + x <= zero => zero <= y => x * y <= zero. +proof. by move=> x_le0 y_le0; rewrite -(mulr0 x) ler_wnmul2l. qed. + +lemma mulr_gt0 (x y : t): + zero < x => zero < y => zero < x * y. +proof. by move=> x_gt0 y_gt0; rewrite pmulr_rgt0. qed. + +(* -------------------------------------------------------------------- *) +lemma ger_pmull (x y : t) : zero < y => (x * y <= y) <=> (x <= oner). +proof. by move=> hy; rewrite -{2}(mul1r y) ler_pmul2r. qed. + +lemma gtr_pmull (x y : t) : zero < y => (x * y < y) <=> (x < oner). +proof. by move=> hy; rewrite -{2}(mul1r y) ltr_pmul2r. qed. + +lemma ger_pmulr (x y : t) : zero < y => (y * x <= y) <=> (x <= oner). +proof. by move=> hy; rewrite -{2}(mulr1 y) ler_pmul2l. qed. + +lemma gtr_pmulr (x y : t) : zero < y => (y * x < y) <=> (x < oner). +proof. by move=> hy; rewrite -{2}(mulr1 y); rewrite ltr_pmul2l. qed. + +lemma ler_pmull (x y : t) : zero < y => (y <= x * y) <=> (oner <= x). +proof. by move=> hy; rewrite -{1}(mul1r y) ler_pmul2r. qed. + +lemma ltr_pmull (x y : t) : zero < y => (y < x * y) <=>(oner < x). +proof. by move=> hy; rewrite -{1}(mul1r y) ltr_pmul2r. qed. + +lemma ler_pmulr (x y : t) : zero < y => (y <= y * x) <=>(oner <= x). +proof. by move=> hy; rewrite -{1}(mulr1 y) ler_pmul2l. qed. + +lemma ltr_pmulr (x y : t) : zero < y => (y < y * x) <=>(oner < x). +proof. by move=> hy; rewrite -{1}(mulr1 y) ltr_pmul2l. qed. + +lemma ger_nmull (x y : t) : y < zero => (x * y <= y) = (oner <= x). +proof. by move=> hy; rewrite -{2}(mul1r y) ler_nmul2r. qed. + +lemma gtr_nmull (x y : t) : y < zero => (x * y < y) = (oner < x). +proof. by move=> hy; rewrite -{2}(mul1r y) ltr_nmul2r. qed. + +lemma ger_nmulr (x y : t) : y < zero => (y * x <= y) = (oner <= x). +proof. by move=> hy; rewrite -{2}(mulr1 y) ler_nmul2l. qed. + +lemma gtr_nmulr (x y : t) : y < zero => (y * x < y) = (oner < x). +proof. by move=> hy; rewrite -{2}(mulr1 y) ltr_nmul2l. qed. + +lemma ler_nmull (x y : t) : y < zero => (y <= x * y) <=> (x <= oner). +proof. by move=> hy; rewrite -{1}(mul1r y) ler_nmul2r. qed. + +lemma ltr_nmull (x y : t) : y < zero => (y < x * y) <=> (x < oner). +proof. by move=> hy; rewrite -{1}(mul1r y) ltr_nmul2r. qed. + +lemma ler_nmulr (x y : t) : y < zero => (y <= y * x) <=> (x <= oner). +proof. by move=> hy; rewrite -{1}(mulr1 y) ler_nmul2l. qed. + +lemma ltr_nmulr (x y : t) : y < zero => (y < y * x) <=> (x < oner). +proof. by move=> hy; rewrite -{1}(mulr1 y) ltr_nmul2l. qed. + +(* -------------------------------------------------------------------- *) +lemma ler_pemull (x y : t) : zero <= y => oner <= x => y <= x * y. +proof. by move=> hy hx; rewrite -{1}(mul1r y) ler_wpmul2r. qed. + +lemma ler_nemull (x y : t) : y <= zero => oner <= x => x * y <= y. +proof. by move=> hy hx; rewrite -{2}(mul1r y) ler_wnmul2r. qed. + +lemma ler_pemulr (x y : t) : zero <= y => oner <= x => y <= y * x. +proof. by move=> hy hx; rewrite -{1}(mulr1 y) ler_wpmul2l. qed. + +lemma ler_nemulr (x y : t) : y <= zero => oner <= x => y * x <= y. +proof. by move=> hy hx; rewrite -{2}(mulr1 y) ler_wnmul2l. qed. + +lemma ler_pimull (x y : t) : zero <= y => x <= oner => x * y <= y. +proof. by move=> hy hx; rewrite -{2}(mul1r y) ler_wpmul2r. qed. + +lemma ler_nimull (x y : t) : y <= zero => x <= oner => y <= x * y. +proof. by move=> hy hx; rewrite -{1}(mul1r y) ler_wnmul2r. qed. + +lemma ler_pimulr (x y : t) : zero <= y => x <= oner => y * x <= y. +proof. by move=> hy hx; rewrite -{2}(mulr1 y) ler_wpmul2l. qed. + +lemma ler_nimulr (x y : t) : y <= zero => x <= oner => y <= y * x. +proof. by move=> hx hy; rewrite -{1}(mulr1 y) ler_wnmul2l. qed. + +(* -------------------------------------------------------------------- *) +lemma mulr_ile1 (x y : t): + zero <= x => zero <= y => x <= oner => y <= oner => x * y <= oner. +proof. by move=> ????; rewrite (@ler_trans y) ?ler_pimull. qed. + +lemma mulr_ilt1 (x y : t): + zero <= x => zero <= y => x < oner => y < oner => x * y < oner. +proof. by move=> ????; rewrite (@ler_lt_trans y) ?ler_pimull // ?ltrW. qed. + +hint rewrite mulr_ilte1 : mulr_ile1 mulr_ilt1. +hint rewrite mulr_cp1 : mulr_ile1 mulr_ilt1. + +(* -------------------------------------------------------------------- *) +lemma mulr_ege1 (x y : t) : oner <= x => oner <= y => oner <= x * y. +proof. +by move=> le1x le1y; rewrite (@ler_trans y) ?ler_pemull // (ler_trans _ ler01). +qed. + +lemma mulr_egt1 (x y : t) : oner < x => oner < y => oner < x * y. +proof. +by move=> le1x lt1y; rewrite (@ltr_trans y) // ltr_pmull // (ltr_trans _ ltr01). +qed. + +hint rewrite mulr_egte1 : mulr_ege1 mulr_egt1. +hint rewrite mulr_cp1 : mulr_ege1 mulr_egt1. + +(* -------------------------------------------------------------------- *) +lemma invr_gt0 (x : t) : (zero < invr x) <=> (zero < x). +proof. +case: (unit x) => [ux|nux]; last by rewrite invr_out. +by split=> /ltr_pmul2r <-; rewrite mul0r (mulrV, mulVr) ?ltr01. +qed. + +lemma invr_ge0 (x : t) : (zero <= invr x) <=> (zero <= x). +proof. by rewrite !le0r invr_gt0 invr_eq0. qed. + +lemma invr_lt0 (x : t) : (invr x < zero) <=> (x < zero). +proof. by rewrite -oppr_cp0 -invrN invr_gt0 oppr_cp0. qed. + +lemma invr_le0 (x : t) : (invr x <= zero) <=> (x <= zero). +proof. by rewrite -oppr_cp0 -invrN invr_ge0 oppr_cp0. qed. + +(* -------------------------------------------------------------------- *) +lemma divr_ge0 (x y : t) : zero <= x => zero <= y => zero <= x / y. +proof. by move=> x_ge0 y_ge0; rewrite mulr_ge0 ?invr_ge0. qed. + +lemma divr_gt0 (x y : t) : zero < x => zero < y => zero < x / y. +proof. by move=> x_gt0 y_gt0; rewrite pmulr_rgt0 ?invr_gt0. qed. + +(* -------------------------------------------------------------------- *) +lemma ler_pinv : + forall (x y : t), unit x => zero < x => unit y => zero < y => + (invr y <= invr x) <=> (x <= y). +proof. +move=> x y Ux hx Uy hy; rewrite -(ler_pmul2l hx) -(ler_pmul2r hy). +by rewrite !(divrr, mulrVK) // mul1r. +qed. + +lemma ler_ninv : + forall (x y : t), unit x => x < zero => unit y => y < zero => + (invr y <= invr x) <=> (x <= y). +proof. +move=> x y Ux hx Uy hy; rewrite -(ler_nmul2l hx) -(ler_nmul2r hy). +by rewrite !(divrr, mulrVK) // mul1r. +qed. + +lemma ltr_pinv : + forall (x y : t), unit x => zero < x => unit y => zero < y => + (invr y < invr x) <=> (x < y). +proof. +move=> x y Ux hx Uy hy; rewrite -(ltr_pmul2l hx) -(ltr_pmul2r hy). +by rewrite !(divrr, mulrVK) // mul1r. +qed. + +lemma ltr_ninv : + forall (x y : t), unit x => x < zero => unit y => y < zero => + (invr y < invr x) <=> (x < y). +proof. +move=> x y Ux hx Uy hy; rewrite -(ltr_nmul2l hx) -(ltr_nmul2r hy). +by rewrite !(divrr, mulrVK) // mul1r. +qed. + +(* -------------------------------------------------------------------- *) +lemma invr_gt1 (x : t) : unit x => zero < x => (oner < invr x) <=> (x < oner). +proof. by move=> Ux gt0_x; rewrite -{1}invr1 ltr_pinv ?unitr1 ?ltr01. qed. + +lemma invr_ge1 (x : t) : unit x => zero < x => (oner <= invr x) <=> (x <= oner). +proof. by move=> Ux gt0_x; rewrite -{1}invr1 ler_pinv ?unitr1 ?ltr01. qed. + +hint rewrite invr_gte1 : invr_ge1 invr_gt1. +hint rewrite invr_cp1 : invr_ge1 invr_gt1. + +lemma invr_le1 (x : t) : unit x => zero < x => (invr x <= oner) <=> (oner <= x). +proof. by move=> ux hx; rewrite -invr_ge1 ?invr_gt0 ?unitrV // invrK. qed. + +lemma invr_lt1 (x : t) : unit x => zero < x => (invr x < oner) <=> (oner < x). +proof. by move=> ux hx; rewrite -invr_gt1 ?invr_gt0 ?unitrV // invrK. qed. + +hint rewrite invr_lte1 : invr_le1 invr_lt1. +hint rewrite invr_cp1 : invr_le1 invr_lt1. + +(* -------------------------------------------------------------------- *) +lemma expr_ge0 n (x : t) : zero <= x => zero <= exp x n. +proof. +move=> ge0_x; elim/intwlog: n. ++ by move=> n; rewrite exprN invr_ge0. ++ by rewrite expr0 ler01. ++ by move=> n ge0_n ge0_e; rewrite exprS // mulr_ge0. +qed. + +lemma expr_gt0 n (x : t) : zero < x => zero < exp x n. +proof. by rewrite !lt0r expf_eq0 => -[->/=]; apply/expr_ge0. qed. + +hint rewrite expr_gte0 : expr_ge0 expr_gt0. + +lemma exprn_ile1 n (x : t) : 0 <= n => zero <= x <= oner => exp x n <= oner. +proof. +move=> nge0 [xge0 xle1]; elim: n nge0; 1: by rewrite expr0. +by move=> n ge0_n ih; rewrite exprS // mulr_ile1 ?expr_ge0. +qed. + +lemma exprn_ilt1 n (x : t) : + 0 <= n => zero <= x < oner => (exp x n < oner) <=> (n <> 0). +proof. +move=> nge0 [xge0 xlt1]; case: n nge0; 1: by rewrite expr0 ltrr. +move=> n nge0 _; rewrite addz_neq0 //=; elim: n nge0; 1: by rewrite expr1. +by move=> n nge0 ih; rewrite exprS 1:addz_ge0 // mulr_ilt1 ?expr_ge0. +qed. + +hint rewrite exprn_ilte1 : exprn_ile1 exprn_ilt1. +hint rewrite exprn_cp1 : exprn_ile1 exprn_ilt1. + +lemma exprn_ege1 n (x : t) : 0 <= n => oner <= x => oner <= exp x n. +proof. +move=> nge0 xge1; elim: n nge0 => [|n nge0 ih]; 1: by rewrite expr0. +by rewrite exprS // mulr_ege1. +qed. + +lemma exprn_egt1 n (x : t) : 0 <= n => oner < x => (oner < exp x n) <=> (n <> 0). +proof. +move=> nge0 xgt1; case: n nge0 => [|n nge0 _]; 1: by rewrite expr0 ltrr. +elim: n nge0 => [|n ge0n]; 1: by rewrite expr1. +rewrite !addz1_neq0 ?addz_ge0 //= => ih. +by rewrite (@exprS _ (n+1)) 1:addz_ge0 // mulr_egt1. +qed. + +hint rewrite exprn_egte1 : exprn_ege1 exprn_egt1. +hint rewrite exprn_cp1 : exprn_ege1 exprn_egt1. + +lemma ler_iexpr (x : t) n : 0 < n => zero <= x <= oner => exp x n <= x. +proof. +rewrite ltz_def => -[nz_n ge0_n]; case: n ge0_n nz_n => // n ge0_n _ _. +by case=> xge0 xlt1; rewrite exprS // ler_pimulr // exprn_ile1. +qed. + +lemma ltr_iexpr (x : t) n : 0 <= n => zero < x < oner => (exp x n < x <=> 1 < n). +proof. +move=> nge0 [xgt0 xlt1]; case: n nge0 => /= [|n nge0 _]. ++ by rewrite expr0 ltrNge ltrW. +case: n nge0 => /= [|n nge0 _]; first by rewrite expr1 ltrr. +rewrite (@ltz_add2r 1 0 (n+1)) -lez_add1r /= lez_addr nge0 /=. +rewrite (@exprS _ (n+1)) 1:addz_ge0 // gtr_pmulr //. +by rewrite exprn_ilt1 ?(addz_neq0, addz_ge0) // ltrW. +qed. + +hint rewrite lter_iexpr : ler_iexpr ltr_iexpr. +hint rewrite lter_expr : ler_iexpr ltr_iexpr. + +lemma ler_eexpr (x : t) n : 0 < n => oner <= x => x <= exp x n. +proof. +rewrite ltz_def => -[nz_n ge0_n]; case: n ge0_n nz_n => //=. +move=> n ge0_n _ _ ge1_x; rewrite exprS //. +by rewrite ler_pemulr 2:exprn_ege1 // &(@ler_trans oner) ?ler01. +qed. + +lemma ltr_eexpr (x : t) n : 0 <= n => oner < x => (x < exp x n <=> 1 < n). +proof. +move=> ge0_n lt1_x; case: n ge0_n; 1: by rewrite expr0 ltrNge ltrW. +move=> + + _; case=> /= [|n ge0_n _]; first by rewrite expr1 ltrr. +rewrite (@ltz_add2r 1 0 (n+1)) -lez_add1r /= lez_addr ge0_n /=. +rewrite (@exprS _ (n+1)) 1:addz_ge0 // ltr_pmulr 1:&(@ltr_trans oner) //. +by rewrite exprn_egt1 // ?(addz_neq0, addz_ge0). +qed. + +hint rewrite lter_eexpr : ler_eexpr ltr_eexpr. +hint rewrite lter_expr : ler_eexpr ltr_eexpr. + +lemma ler_wiexpn2l (x : t) : zero <= x <= oner => + forall m n, 0 <= n <= m => exp x m <= exp x n. +proof. +move=> [xge0 xle1] m n [ge0_n le_nm]; have ->: m = (m - n) + n by ring. +by rewrite exprD_nneg 1:subz_ge0 // ler_pimull ?(expr_ge0, exprn_ile1) ?subz_ge0. +qed. + +lemma ler_weexpn2l (x : t) : oner <= x => + forall m n, 0 <= m <= n => exp x m <= exp x n. +proof. +move=> ge1_x m n [ge0_m le_mn]; have ->: n = (n - m) + m by ring. +rewrite exprD_nneg 1:subz_ge0 // ler_pemull ?(expr_ge0, exprn_ege1) //. ++ by rewrite (@ler_trans oner). + by rewrite subz_ge0. +qed. + +lemma ler_weexpn2r (x : t) : oner < x => + forall m n, 0 <= m => 0 <= n => exp x m <= exp x n => m <= n. +proof. +move => lt1x m n le0m le0n; rewrite -implybNN -ltrNge -ltzNge ltzE => le_m; apply (ltr_le_trans (exp x (n + 1))). ++ by rewrite exprS //; apply ltr_pmull => //; apply/expr_gt0/(ler_lt_trans oner). +by apply ler_weexpn2l; [apply ltrW|split => //; apply addz_ge0]. +qed. + +lemma ieexprn_weq1 (x : t) n : 0 <= n => zero <= x => + (exp x n = oner) <=> (n = 0 || x = oner). +proof. +case: n => [|n ge0_n _] ge0_x; first by rewrite expr0. +rewrite !addz_neq0 //=; split=> [|->]; last by rewrite expr1z. +case: (x = oner) => [->//|/ltr_total [] hx] /=. ++ by rewrite ltr_eqF // exprn_ilt1 // (addz_ge0, addz_neq0). ++ by rewrite gtr_eqF // exprn_egt1 // (addz_ge0, addz_neq0). +qed. + +lemma ieexprIn (x : t) : zero < x => x <> oner => + forall m n, 0 <= m => 0 <= n => exp x m = exp x n => m = n. +proof. +(* FIXME: wlog *) +move=> gt0_x neq1_x m n; pose P := fun m n => 0 <= m => 0 <= n => + exp x m = exp x n => m = n; rewrite -/(P m n). +have: (forall m n, (m <= n)%Int => P m n) => P m n. ++ move=> ih; case: (lez_total m n); first by apply/ih. + by move/ih=> @/P h *; rewrite -h // eq_sym. +apply=> {m n} m n le_mn ge0_m ge0_n {P}. +have ->: n = m + (n - m) by ring. +rewrite exprD_nneg 2:subz_ge0 // -{1}(mulr1 (exp x m)). +have h/h{h} := mulfI (exp x m) _; first by rewrite expf_eq0 gtr_eqF. +by rewrite eq_sym ieexprn_weq1 1?(subz_ge0, ltrW) //#. +qed. + +lemma ler_pexp n (x y : t) : + 0 <= n => zero <= x <= y => exp x n <= exp y n. +proof. +move=> h; elim/intind: n h x y => [|n ge0_n ih] x y [ge0_x le_xy]. ++ by rewrite !expr0. ++ by rewrite !exprS // ler_pmul // ?expr_ge0 ?ih. +qed. + +lemma ge0_sqr (x : t) : zero <= exp x 2. +proof. +rewrite expr2; case: (zero <= x); first by move=> h; apply/mulr_ge0. +by rewrite lerNgt /= => /ltrW le0_x; apply/mulr_le0. +qed. + +(* -------------------------------------------------------------------- *) +lemma ler_norm_sub (x y : t): + `|x - y| <= `|x| + `|y|. +proof. by rewrite -(@normrN y) ler_norm_add. qed. + +lemma ler_dist_add (z x y : t): + `|x - y| <= `|x - z| + `|z - y|. +proof. +apply/(ler_trans _ _ (:@ler_norm_add (x-z) (z-y))). +by rewrite addrA addrNK lerr. +qed. + +lemma ler_sub_norm_add (x y : t): + `|x| - `|y| <= `|x + y|. +proof. +rewrite -{1}(@addrK y x) lter_sub_addl; + rewrite (ler_trans _ (:@ler_norm_add (x+y) (-y))) //. +by rewrite addrC normrN lerr. +qed. + +lemma ler_sub_dist (x y : t): + `|x| - `|y| <= `|x - y|. +proof. by rewrite -(@normrN y) ler_sub_norm_add. qed. + +lemma ler_dist_dist (x y : t): + `| `|x| - `|y| | <= `|x - y|. +proof. +case: (`|x| <= `|y|); last first. + rewrite -ltrNge=> /ltrW le_yx; + by rewrite ger0_norm ?ler_sub_dist // subr_ge0. +move=> le_xy; rewrite ler0_norm ?subr_le0 //. +by rewrite distrC opprB ler_sub_dist. +qed. + +lemma ler_dist_norm_add (x y : t): + `| `|x| - `|y| | <= `|x + y|. +proof. by rewrite -(@opprK y) normrN ler_dist_dist. qed. + +lemma ler_nnorml (x y : t): y < zero => ! (`|x| <= y). +proof. by move=> y_lt0; rewrite ltr_geF // (ltr_le_trans _ y_lt0) ?normr_ge0. qed. + +lemma ltr_nnorml (x y : t): y <= zero => ! (`|x| < y). +proof. by move=> y_le0; rewrite ler_gtF // (ler_trans _ y_le0) ?normr_ge0. qed. + +lemma eqr_norm_id (x : t): (`|x| = x) <=> (zero <= x). +proof. by rewrite ger0_def. qed. + +lemma eqr_normN (x : t): (`|x| = - x) <=> (x <= zero). +proof. by rewrite ler0_def. qed. + +lemma normE (n : t) : + `|n| = if zero <= n then n else -n. +proof. +move: (real_axiom n); rewrite or_andr => -[le0n|[Nle0n len0]]. ++ by rewrite le0n /= eqr_norm_id. +by rewrite Nle0n /= eqr_normN. +qed. + +(* -------------------------------------------------------------------- *) +lemma ler_norm (x : t) : x <= `|x|. +proof. +case: (zero <= x); first by move/ger0_norm=> ->; apply/lerr. +move/ltrNge=> /ltrW ^h /ler0_norm ->; apply/(ler_trans zero)=> //. +by rewrite oppr_ge0. +qed. + +lemma eqr_norml (x y : t) : (`|x| = y) <=> ((x = y) \/ (x = -y)) /\ (zero <= y). +proof. +split=> [|[]]; last by case=> -> h; rewrite ?normrN ger0_norm. +move=> <-; rewrite normr_ge0 /=; case: (x <= zero) => [|/ltrNge]. + by move/ler0_norm=> ->; rewrite opprK. +by move/gtr0_norm=> ->. +qed. + +(* -------------------------------------------------------------------- *) +lemma ler_norml (x y : t) : (`|x| <= y) <=> (- y <= x <= y). +proof. +have h: forall (z : t), zero <= z => (z <= y) <=> (- y <= z <= y). + move=> z ge0_z; case: (z <= y)=> //= le_zy; apply/(ler_trans zero)=> //. + by rewrite oppr_le0 (ler_trans z). +case: (zero <= x) => [^ge0_x /h|/ltrNge/ltrW ge0_x]; first by rewrite ger0_norm. +rewrite -(opprK x) normrN ler_opp2 andaE andbC ler_oppl h. + by rewrite normr_ge0. by rewrite ger0_norm // oppr_ge0. +qed. + +lemma ltr_normr (x y : t) : (x < `|y|) <=> (x < y) \/ (x < - y). +proof. by rewrite ltrNge ler_norml andaE negb_and -!ltrNge ltr_oppr orbC. qed. + +lemma ltr_norml : forall (x y : t), (`|x| < y) <=> (- y < x < y). +proof. +have h: + (forall (x y : t), zero <= x => (`|x| < y) <=> (- y < x < y)) + => forall (x y : t), (`|x| < y) <=> (- y < x < y). ++ move=> wlog x y; case: (leVge zero x) => [/wlog|hx]; 1: by apply. + rewrite -(opprK x) normrN wlog ?oppr_ge0 //. + by rewrite !ltr_opp2 !andaE andbC opprK. +apply/h=> x y hx; rewrite ger0_norm //; case: (x < y) => //= le_xy. +by rewrite (ltr_le_trans _ _ hx) oppr_lt0 (ler_lt_trans _ hx). +qed. + +lemma ler_normr (x y : t) : (x <= `|y|) <=> (x <= y) \/ (x <= - y). +proof. +by rewrite lerNgt ltr_norml // andaE negb_and !lerNgt orbC ltr_oppl. +qed. + +(* -------------------------------------------------------------------- *) +lemma maxrC (x y : t) : maxr x y = maxr y x. +proof. by rewrite !maxrE lerNgt ler_eqVlt; case: (x = y); case: (x < y). qed. + +lemma maxrA (x y z: t): maxr (maxr x y) z = maxr x (maxr y z). +proof. +rewrite !maxrE. +case (y <= x); case (z <= y); case (z <= x) => // + [/#||/#|/#|]. +- smt(ler_trans). +- smt(ltr_trans ltrNge). +qed. + +lemma maxrl (x y : t) : x <= maxr x y. +proof. by rewrite maxrE; case: (y <= x) => [_|/ltrNge/ltrW]. qed. + +lemma maxrr (x y : t) : y <= maxr x y. +proof. by rewrite maxrC maxrl. qed. + +lemma ler_maxr (x y : t) : x <= y => maxr x y = y. +proof. by rewrite maxrE lerNgt ler_eqVlt => -> /#. qed. + +lemma ler_maxl (x y : t) : y <= x => maxr x y = x. +proof. by rewrite maxrC &(ler_maxr). qed. + +lemma maxr_ub (x y : t) : x <= maxr x y /\ y <= maxr x y. +proof. by rewrite maxrl maxrr. qed. + +lemma ler_maxrP (m n1 n2 : t) : (maxr n1 n2 <= m) <=> (n1 <= m) /\ (n2 <= m). +proof. +split; last by case=> le1 le2; rewrite maxrE; case: (n2 <= n1). +rewrite maxrE; case: (n2 <= n1). +* by move=> le_21 le_n1m; rewrite (ler_trans _ le_21 le_n1m). +* rewrite lerNgt /= => /ltrW le_12 le_n1m. + by rewrite (ler_trans _ le_12 le_n1m). +qed. + +lemma ltr_maxrP (m n1 n2 : t) : (maxr n1 n2 < m) <=> (n1 < m) /\ (n2 < m). +proof. +split; last by case=> le1 le2; rewrite maxrE; case: (n2 <= n1). +rewrite maxrE; case: (n2 <= n1). +* by move=> le_21 lt_n1m; rewrite (ler_lt_trans _ le_21 lt_n1m). +* rewrite lerNgt /= => lt_12 lt_n1m. + by rewrite (ltr_trans _ lt_12 lt_n1m). +qed. + +lemma ler_maxr_trans (x1 x2 y1 y2 : t) : + x1 <= x2 => y1 <= y2 => maxr x1 y1 <= maxr x2 y2. +proof. + by move=> hx hy; rewrite ler_maxrP; case (maxr_ub x2 y2) => hx' hy'; split; + [apply: ler_trans hx' | apply: ler_trans hy']. +qed. + +lemma ler_norm_maxr (x1 x2 : t) : + zero <= x1 => + zero <= x2 => + `| x1 - x2 | <= maxr x1 x2. +proof. + rewrite maxrE normE; case: (x2 <= x1). + + rewrite subr_ge0 => -> /= *; apply ler_subr_addr. + by rewrite opprK ler_addl. + rewrite ler_subr_addr add0r => -> /=. + by rewrite opprB -ler_subr_addr opprK ler_addl. +qed. + +(* -------------------------------------------------------------------- *) +lemma minrC (x y : t) : minr x y = minr y x. +proof. by rewrite !minrE lerNgt ler_eqVlt; case: (y = x); case: (y < x). qed. + +lemma minrA (x y z : t) : minr (minr x y) z = minr x (minr y z). +proof. +rewrite !minrE. +case (x <= y); case (y <= z); case (x <= z) => // + [/#||/#|/#|]. +- smt(ler_trans). +- smt(ltr_trans ltrNge). +qed. + +lemma minrl (x y : t) : minr x y <= x. +proof. by rewrite minrE; case: (x <= y) => [_|/ltrNge/ltrW]. qed. + +lemma minrr (x y : t) : minr x y <= y. +proof. by rewrite minrC minrl. qed. + +lemma ler_minl (x y : t) : x <= y => minr x y = x. +proof. by rewrite minrE lerNgt => ->. qed. + +lemma ler_minr (x y : t) : y <= x => minr x y = y. +proof. by rewrite minrC &(ler_minl). qed. + +lemma minr_lb (x y : t) : minr x y <= x /\ minr x y <= y. +proof. by rewrite minrl minrr. qed. + +end section. + +(* ==================================================================== *) +(* Real-closed field: a [tcrealdomain] where every nonzero element is *) +(* invertible (the field axiom). Mirrors *) +(* [theories/algebra/Number.ec:RealField]. We extend [tcrealdomain] *) +(* (single parent) and add the field axiom locally rather than *) +(* multi-inherit from [tcrealdomain & field]: under multi-parent *) +(* inheritance, both parent paths reach [comring] / [idomain] *) +(* without renamings, leaving [invr]'s parent-DAG witness ambiguous *) +(* across applications and breaking proof terms downstream. *) +(* ==================================================================== *) +type class tcrealfield <: tcrealdomain & field = {}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: tcrealfield. + +(* -------------------------------------------------------------------- *) +lemma lef_pinv (x y : t) : + zero < x => zero < y => (invr y <= invr x) <=> (x <= y). +proof. by move=> hx hy; apply/ler_pinv => //; apply/unitfP/gtr_eqF. qed. + +lemma lef_ninv (x y : t) : + x < zero => y < zero => (invr y <= invr x) <=> (x <= y). +proof. by move=> hx hy; apply/ler_ninv => //; apply/unitfP/ltr_eqF. qed. + +lemma ltf_pinv (x y : t) : + zero < x => zero < y => (invr y < invr x) <=> (x < y). +proof. by move=> hx hy; apply/ltr_pinv => //; apply/unitfP/gtr_eqF. qed. + +lemma ltf_ninv (x y : t) : + x < zero => y < zero => (invr y < invr x) <=> (x < y). +proof. by move=> hx hy; apply/ltr_ninv => //; apply/unitfP/ltr_eqF. qed. + +(* -------------------------------------------------------------------- *) +lemma ler_pdivl_mulr (z x y : t) : + zero < z => (x <= y / z) <=> (x * z <= y). +proof. by move=> z_gt0; rewrite -(@ler_pmul2r z) // mulrVK ?unitfP ?gtr_eqF. qed. + +lemma ltr_pdivl_mulr (z x y : t) : + zero < z => (x < y / z) <=> (x * z < y). +proof. by move=> z_gt0; rewrite -(@ltr_pmul2r z) // mulrVK ?unitfP ?gtr_eqF. qed. + +hint rewrite lter_pdivl_mulr : ler_pdivl_mulr ltr_pdivl_mulr. + +(* -------------------------------------------------------------------- *) +lemma ler_pdivr_mulr (z x y : t) : + zero < z => (y / z <= x) <=> (y <= x * z). +proof. by move=> z_gt0; rewrite -(@ler_pmul2r z) // mulrVK ?unitfP ?gtr_eqF. qed. + +lemma ltr_pdivr_mulr (z x y : t) : + zero < z => (y / z < x) <=> (y < x * z). +proof. by move=> z_gt0; rewrite -(@ltr_pmul2r z) // mulrVK ?unitfP ?gtr_eqF. qed. + +hint rewrite lter_pdivr_mulr : ler_pdivr_mulr ltr_pdivr_mulr. + +(* -------------------------------------------------------------------- *) +lemma ler_pdivl_mull (z x y : t) : + zero < z => (x <= invr z * y) <=> (z * x <= y). +proof. by move=> z_gt0; rewrite mulrC ler_pdivl_mulr ?(@mulrC z). qed. + +lemma ltr_pdivl_mull (z x y : t) : + zero < z => (x < invr z * y) <=> (z * x < y). +proof. by move=> z_gt0; rewrite mulrC ltr_pdivl_mulr ?(@mulrC z). qed. + +hint rewrite lter_pdivl_mull : ler_pdivl_mull ltr_pdivl_mull. + +(* -------------------------------------------------------------------- *) +lemma ler_pdivr_mull (z x y : t) : + zero < z => (invr z * y <= x) <=> (y <= z * x). +proof. by move=> z_gt0; rewrite mulrC ler_pdivr_mulr ?(@mulrC z). qed. + +lemma ltr_pdivr_mull (z x y : t) : + zero < z => (invr z * y < x) <=> (y < z * x). +proof. by move=> z_gt0; rewrite mulrC ltr_pdivr_mulr ?(@mulrC z). qed. + +hint rewrite lter_pdivr_mull : ler_pdivr_mull ltr_pdivr_mull. + +(* -------------------------------------------------------------------- *) +lemma ler_ndivl_mulr (z x y : t) : + z < zero => (x <= y / z) <=> (y <= x * z). +proof. by move=> z_lt0; rewrite -(@ler_nmul2r z) // mulrVK ?unitfP ?ltr_eqF. qed. + +lemma ltr_ndivl_mulr (z x y : t) : + z < zero => (x < y / z) <=> (y < x * z). +proof. by move=> z_lt0; rewrite -(@ltr_nmul2r z) // mulrVK ?unitfP ?ltr_eqF. qed. + +hint rewrite lter_ndivl_mulr : ler_ndivl_mulr ltr_ndivl_mulr. + +(* -------------------------------------------------------------------- *) +lemma ler_ndivr_mulr (z x y : t) : + z < zero => (y / z <= x) <=> (x * z <= y). +proof. by move=> z_lt0; rewrite -(@ler_nmul2r z) // mulrVK ?unitfP ?ltr_eqF. qed. + +lemma ltr_ndivr_mulr (z x y : t) : + z < zero => (y / z < x) <=> (x * z < y). +proof. by move=> z_lt0; rewrite -(@ltr_nmul2r z) // mulrVK ?unitfP ?ltr_eqF. qed. + +hint rewrite lter_ndivr_mulr : ler_ndivr_mulr ltr_ndivr_mulr. + +(* -------------------------------------------------------------------- *) +lemma ler_ndivl_mull (z x y : t) : + z < zero => (x <= invr z * y) <=> (y <= z * x). +proof. by move=> z_lt0; rewrite mulrC ler_ndivl_mulr ?(@mulrC z). qed. + +lemma ltr_ndivl_mull (z x y : t) : + z < zero => (x < invr z * y) <=> (y < z * x). +proof. by move=> z_lt0; rewrite mulrC ltr_ndivl_mulr ?(@mulrC z). qed. + +hint rewrite lter_ndivl_mull : ler_ndivl_mull ltr_ndivl_mull. + +(* -------------------------------------------------------------------- *) +lemma ler_ndivr_mull (z x y : t) : + z < zero => (invr z * y <= x) <=> (z * x <= y). +proof. by move=> z_lt0; rewrite mulrC ler_ndivr_mulr ?(@mulrC z). qed. + +lemma ltr_ndivr_mull (z x y : t) : + z < zero => (invr z * y < x) <=> (z * x < y). +proof. by move=> z_lt0; rewrite mulrC ltr_ndivr_mulr ?(@mulrC z). qed. + +hint rewrite lter_ndivr_mull : ler_ndivr_mull ltr_ndivr_mull. + +end section. + +(* ==================================================================== *) +(* Canonical [int] instance for [tcrealdomain]. Mirrors *) +(* [theories/algebra/Number.ec]'s int specialisation. *) +(* ==================================================================== *) +op int_norm = CoreInt.absz. +op int_le = CoreInt.le. +op int_lt = CoreInt.lt. +op int_min = Int.min. +op int_max = Int.max. + +instance tcrealdomain with int + op "`|_|" = int_norm + op (<=) = int_le + op (<) = int_lt + op minr = int_min + op maxr = int_max + + proof ler_norm_add by smt() + proof addr_gt0 by smt() + proof norm_eq0 by smt() + proof ger_leVge by smt() + proof normrM by smt() + proof ler_def by smt() + proof ltr_def by smt() + proof real_axiom by smt() + proof minrE by smt() + proof maxrE by smt(). diff --git a/examples/tcalgebra/TcPoly.ec b/examples/tcalgebra/TcPoly.ec new file mode 100644 index 0000000000..e378658654 --- /dev/null +++ b/examples/tcalgebra/TcPoly.ec @@ -0,0 +1,986 @@ +(* -------------------------------------------------------------------- *) +require import AllCore Finite Distr DList List IntMin StdBigop StdOrder. +require Subtype. +require import TcMonoid TcRing TcBigop TcBigalg TcInt. +(*---*) import Bigint IntOrder. + +(* ==================================================================== *) +(* Univariate polynomials over a [comring] coefficient algebra. Mirrors *) +(* [theories/algebra/Poly.ec:PolyComRing] but as a section over [c] *) +(* with TC instances accumulating: once [poly : addgroup] is registered *) +(* in Phase 3, every [bigA] / [bigZModule] lemma applies to polynomial *) +(* sums; once [poly : comring] in Phase 5, every [bigA]/[bigM] lemma *) +(* in TcBigalg applies. No "BigPoly" clone needed. *) +(* ==================================================================== *) + +section. +declare type c <: comring. + +(* -------------------------------------------------------------------- *) +(* prepoly = sequence-of-coeffs predicate; poly = subtype thereof *) +(* -------------------------------------------------------------------- *) +type prepoly = int -> c. + +op ispoly (p : prepoly) = + (forall i, i < 0 => p i = zero<:c>) + /\ (exists d, forall i, d < i => p i = zero<:c>). + +subtype poly = { p : prepoly | ispoly p } + rename "to_poly", "of_poly". + +realize inhabited. +proof. by exists (fun _ => zero<:c>). qed. + +(* -------------------------------------------------------------------- *) +op "_.[_]" (p : poly) (i : int) = (of_poly p) i. + +lemma lt0_coeff p i : i < 0 => p.[i] = zero<:c>. +proof. +by move=> lt0_i; rewrite /"_.[_]"; case: (of_polyP p) => /(_ _ lt0_i). +qed. + +(* -------------------------------------------------------------------- *) +(* Degree machinery *) +(* -------------------------------------------------------------------- *) +op deg (p : poly) = + argmin idfun (fun i => forall j, i <= j => p.[j] = zero<:c>). + +lemma degP p i : + 0 < i + => p.[i-1] <> zero<:c> + => (forall j, i <= j => p.[j] = zero<:c>) + => deg p = i. +proof. +move=> ge0_i nz_p_iB1 degi @/deg; apply: argmin_eq => /=. +- by apply/ltrW. - by apply: degi. +move=> j [ge0_j lt_ji]; rewrite negb_forall /=. +by exists (i-1); apply/negP => /(_ _); first by move=> /#. +qed. + +lemma deg_leP p i : 0 <= i => + (forall j, i <= j => p.[j] = zero<:c>) => deg p <= i. +proof. +move=> ge0_i; apply: contraLR; rewrite lerNgt /= => lei. +by have @{1}/deg /argmin_min /=: 0 <= i < deg p by done. +qed. + +lemma gedeg_coeff (p : poly) (i : int) : deg p <= i => p.[i] = zero<:c>. +proof. +move=> le_p_i; pose P p i := forall j, i <= j => p.[j] = zero<:c>. +case: (of_polyP p) => [_ [d hd]]; move: (argminP idfun (P p)). +move/(_ (max (d+1) 0) _ _) => /=; first exact: maxrr. +- by move=> j le_d_j; apply: hd => /#. +by apply; apply: le_p_i. +qed. + +lemma ge0_deg p : 0 <= deg p. +proof. rewrite /deg &(ge0_argmin). qed. + +(* -------------------------------------------------------------------- *) +abbrev lc (p : poly) = p.[deg p - 1]. + +(* -------------------------------------------------------------------- *) +(* prepoly-level constructors *) +(* -------------------------------------------------------------------- *) +op prepolyC (a : c ) : prepoly = fun i => if i = 0 then a else zero<:c>. +op prepolyXn (k : int ) : prepoly = fun i => if 0 <= k /\ i = k then oner<:c> else zero<:c>. +op prepolyD (p q : poly) : prepoly = fun i => p.[i] + q.[i]. +op prepolyN (p : poly) : prepoly = fun i => - p.[i]. + +op prepolyM (p q : poly) : prepoly = fun k => + bigiA<:c> predT (fun i => p.[i] * q.[k-i]) 0 (k+1). + +op prepolyZ (z : c) (p : poly) : prepoly = fun k => + z * p.[k]. + +(* -------------------------------------------------------------------- *) +(* ispoly closure *) +(* -------------------------------------------------------------------- *) +lemma ispolyC (a : c) : ispoly (prepolyC a). +proof. +split=> @/prepolyC [c' ?|]; first by rewrite ltr_eqF. +by exists 0 => c' gt1_c'; rewrite gtr_eqF. +qed. + +lemma ispolyXn (k : int) : ispoly (prepolyXn k). +proof. +split=> @/prepolyXn [c' lt0_c|]. ++ by case: (0 <= k) => //= ge0_k; rewrite ltr_eqF //#. ++ by exists k => c' gt1_c'; rewrite gtr_eqF. +qed. + +lemma ispolyN (p : poly) : ispoly (prepolyN p). +proof. +split=> @/prepolyN [c' lt0_c|]; first by rewrite oppr_eq0 lt0_coeff. +by exists (deg p) => c' => /ltrW /gedeg_coeff ->; rewrite oppr0. +qed. + +lemma ispolyD (p q : poly) : ispoly (prepolyD p q). +proof. +split=> @/prepolyD [c' lt0_c|]; first by rewrite !lt0_coeff // addr0. +by exists (1 + max (deg p) (deg q)) => c' le; rewrite !gedeg_coeff ?addr0 //#. +qed. + +lemma ispolyM (p q : poly) : ispoly (prepolyM p q). +proof. +split => @/prepolyM [c' lt0_c|]; 1: by rewrite big_geq //#. +exists (deg p + deg q + 1) => c' ltc; rewrite big_seq big1 //= => i. +rewrite mem_range => -[gt0_i lt_ic]; case: (p.[i] = zero<:c>). +- by move=> ->; rewrite mul0r. +move/(contra _ _ (gedeg_coeff p i)); rewrite lerNgt /= => lt_ip. +by rewrite mulrC gedeg_coeff ?mul0r //#. +qed. + +lemma ispolyZ z p : ispoly (prepolyZ z p). +proof. +split => @/prepolyZ [c' lt0_c|]; 1: by rewrite lt0_coeff //mulr0. +by exists (deg p + 1) => c' gtc; rewrite gedeg_coeff ?mulr0 //#. +qed. + +lemma poly_eqP (p q : poly) : p = q <=> (forall i, 0 <= i => p.[i] = q.[i]). +proof. +split=> [->//|eq_coeff]; apply/of_poly_inj/fun_ext => i. +case: (i < 0) => [lt0_i|/lerNgt /=]; last by apply: eq_coeff. +by rewrite -/"_.[_]" !lt0_coeff. +qed. + +(* -------------------------------------------------------------------- *) +(* poly-level constructors *) +(* -------------------------------------------------------------------- *) +op polyC a = to_polyd (prepolyC a). +op polyXn k = to_polyd (prepolyXn k). +op polyN p = to_polyd (prepolyN p). +op polyD p q = to_polyd (prepolyD p q). +op polyM p q = to_polyd (prepolyM p q). +op polyZ z p = to_polyd (prepolyZ z p). + +abbrev poly0 : poly = polyC zero<:c>. +abbrev poly1 : poly = polyC oner<:c>. +abbrev polyX : poly = polyXn 1. +abbrev X : poly = polyXn 1. +abbrev ( + ) (p q : poly) : poly = polyD p q. +abbrev [ - ] (p : poly) : poly = polyN p. +abbrev ( * ) (p q : poly) : poly = polyM p q. +abbrev ( ** ) z (p : poly) : poly = polyZ z p. + +abbrev ( - ) (p q : poly) : poly = p + (-q). + +(* -------------------------------------------------------------------- *) +(* Coefficient formulas *) +(* -------------------------------------------------------------------- *) +lemma coeffE p k : ispoly p => (to_polyd p).[k] = p k. +proof. by move=> ?; rewrite /"_.[_]" to_polydK. qed. + +lemma polyCE a k : (polyC a).[k] = if k = 0 then a else zero<:c>. +proof. by rewrite coeffE 1:ispolyC. qed. + +lemma polyXE k : X.[k] = if k = 1 then oner<:c> else zero<:c>. +proof. by rewrite coeffE 1:ispolyXn. qed. + +lemma poly0E k : poly0.[k] = zero<:c>. +proof. by rewrite polyCE if_same. qed. + +lemma polyNE p k : (-p).[k] = - p.[k]. +proof. by rewrite coeffE 1:ispolyN. qed. + +lemma polyDE p q k : (p + q).[k] = p.[k] + q.[k]. +proof. by rewrite coeffE 1:ispolyD. qed. + +lemma polyME p q k : (p * q).[k] = + bigiA<:c> predT (fun i => p.[i] * q.[k-i]) 0 (k+1). +proof. by rewrite coeffE 1:ispolyM. qed. + +lemma polyMXE p k : (p * X).[k] = p.[k-1]. +proof. +case: (k < 0) => [lt0_k|]; first by rewrite !lt0_coeff //#. +rewrite ltrNge => /= ge0_k; rewrite polyME; move: ge0_k. +rewrite ler_eqVlt => -[<-|gt0_k] /=. +- by rewrite big_int1 /= polyXE /= mulr0 lt0_coeff. +rewrite (@bigD1<:c, int> _ _ (k-1)) ?mem_range 1:/# 1:range_uniq /=. +rewrite opprB addrCA /= polyXE /= mulr1 big1 // ?addr0 //. +move=> i @/predC1 nei /=; rewrite polyXE. +case: (k - i = 1) => [/#|_ /=]; first by rewrite mulr0. +qed. + +lemma polyZE z p k : (z ** p).[k] = z * p.[k]. +proof. by rewrite coeffE 1:ispolyZ. qed. + +hint rewrite coeffpE : poly0E polyDE polyNE polyME polyZE. + +(* -------------------------------------------------------------------- *) +(* polyC properties *) +(* -------------------------------------------------------------------- *) +lemma polyCN (a : c) : polyC (- a) = - (polyC a). +proof. +apply/poly_eqP=> i ge0_i; rewrite !(coeffpE, polyCE). +by case: (i = 0) => // _; rewrite oppr0. +qed. + +lemma polyCD (a1 a2 : c) : polyC (a1 + a2) = polyC a1 + polyC a2. +proof. +apply/poly_eqP=> i ge0_i; rewrite !(coeffpE, polyCE). +by case: (i = 0) => // _; rewrite addr0. +qed. + +lemma polyCM (a1 a2 : c) : polyC (a1 * a2) = polyC a1 * polyC a2. +proof. +apply/poly_eqP=> i ge0_i; rewrite !(coeffpE, polyCE). +case: (i = 0) => [->|ne0_i]; first by rewrite big_int1 /= !polyCE. +rewrite big_seq big1 ?addr0 //= => j /mem_range rg_j. +rewrite !polyCE; case: (j = 0) => [->>/=|]; last by rewrite mul0r. +by rewrite ne0_i /= mulr0. +qed. + +(* -------------------------------------------------------------------- *) +(* ZModule axioms on poly. Mirrors original [clone Ring.ZModule as *) +(* ZPoly] but as standalone lemmas; will feed into the [addgroup] *) +(* instance in Phase 3. *) +(* -------------------------------------------------------------------- *) +lemma polyD_addrA (p q r : poly) : p + (q + r) = (p + q) + r. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE addrA. qed. + +lemma polyD_addrC (p q : poly) : p + q = q + p. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE addrC. qed. + +lemma polyD_add0r (p : poly) : poly0 + p = p. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE add0r. qed. + +lemma polyD_addNr (p : poly) : (-p) + p = poly0. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE addNr. qed. + +(* -------------------------------------------------------------------- *) +(* Scaling lemmas *) +(* -------------------------------------------------------------------- *) +lemma scale0p p : zero<:c> ** p = poly0. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mul0r. qed. + +lemma scalep0 a : a ** poly0 = poly0. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mulr0. qed. + +lemma scale1p p : oner<:c> ** p = p. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mul1r. qed. + +lemma scalep1 (a : c) : a ** poly1 = polyC a. +proof. +apply/poly_eqP=> i ge0_i; rewrite !coeffpE !polyCE. +by case: (i = 0) => _; [rewrite mulr1|rewrite mulr0]. +qed. + +lemma scaleNp (a : c) p : (-a) ** p = - (a ** p). +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mulNr. qed. + +lemma scalepN (a : c) p : a ** (-p) = - (a ** p). +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mulrN. qed. + +lemma scalepA (a1 a2 : c) p : a1 ** (a2 ** p) = (a1 * a2) ** p. +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mulrA. qed. + +lemma scalepDr (a : c) p q : a ** (p + q) = (a ** p) + (a ** q). +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mulrDr. qed. + +lemma scalepBr (a : c) p q : a ** (p - q) = (a ** p) - (a ** q). +proof. by rewrite scalepDr scalepN. qed. + +lemma scalepDl (a1 a2 : c) p : (a1 + a2) ** p = (a1 ** p) + (a2 ** p). +proof. by apply/poly_eqP=> i ge0_i; rewrite !coeffpE mulrDl. qed. + +lemma scalepBl (a1 a2 : c) p : (a1 - a2) ** p = (a1 ** p) - (a2 ** p). +proof. by rewrite scalepDl scaleNp. qed. + +lemma scalepE (a : c) p : a ** p = polyC a * p. +proof. +apply/poly_eqP=> i ge0_i; rewrite !coeffpE /=. +rewrite big_int_recl //= polyCE /=. +rewrite big_seq big1 ?addr0 //= => j /mem_range. +by case=> ge0_j _; rewrite polyCE addz1_neq0 //= mul0r. +qed. + +(* -------------------------------------------------------------------- *) +(* Multiplication: extended coefficient formulas, then the ComRing *) +(* axioms (associativity / commutativity / unit / distributivity). *) +(* Mirrors original [Poly.ec] lines 418-498. *) +(* -------------------------------------------------------------------- *) +lemma polyMEw M p q k : k <= M => + (p * q).[k] = bigiA<:c> predT (fun i => p.[i] * q.[k-i]) 0 (M+1). +proof. +move=> le_kM; case: (k < 0) => [lt0_k|/lerNgt ge0_k]. ++ rewrite lt0_coeff // big_seq big1 //= => i. + by case/mem_range=> [ge0_i lt_iM]; rewrite (lt0_coeff q) ?mulr0 //#. +rewrite (@big_cat_int (k+1)) 1,2:/# -polyME. +rewrite big_seq big1 2:addr0 //= => i /mem_range. +by case=> [lt_ki lt_iM]; rewrite (lt0_coeff q) ?mulr0 //#. +qed. + +lemma polyM_mulrC : commutative polyM. +proof. +move=> p q; apply: poly_eqP => k ge0_k; rewrite !polyME. +pose F j := k - j; rewrite (@big_reindex _ _ F F) 1:/#. +rewrite predT_comp /(\o) /=; pose s := map _ _. +apply: (eq_trans _ _ _ (eq_big_perm _ _ _ (range 0 (k+1)) _)). ++ rewrite uniq_perm_eq 2:&(range_uniq) /s. + * rewrite map_inj_in_uniq 2:&(range_uniq) => x y. + by rewrite !mem_range /F /#. + * move=> x; split => [/mapP[y []]|]; 1: by rewrite !mem_range /#. + rewrite !mem_range => *; apply/mapP; exists (F x). + by rewrite !mem_range /F /#. ++ by apply: eq_bigr => /= i _ @/F; rewrite mulrC /#. +qed. + +lemma polyMEwr M p q k : k <= M => + (p * q).[k] = bigiA<:c> predT (fun i => p.[k-i] * q.[i]) 0 (M+1). +proof. +rewrite -{1}polyM_mulrC => /polyMEw ->; apply: eq_bigr. +by move=> i _ /=; rewrite mulrC. +qed. + +lemma polyMEr p q k : + (p * q).[k] = bigiA<:c> predT (fun i => p.[k-i] * q.[i]) 0 (k+1). +proof. by rewrite (@polyMEwr k). qed. + +lemma polyM_mulrA : associative polyM. +proof. +move=> p q r; apply: poly_eqP => k ge0_k. +have ->: (p * (q * r)).[k] = + bigiA<:c> predT (fun i => + bigiA<:c> predT (fun j => p.[i] * q.[k - i - j] * r.[j]) 0 (k+1) + ) 0 (k+1). ++ rewrite polyME !big_seq &(eq_bigr) => /= i. + case/mem_range => g0_i lt_i_Sk; rewrite (@polyMEwr k) 1:/#. + by rewrite mulr_sumr &(eq_bigr) => /= j _; rewrite mulrA. +have ->: ((p * q) * r).[k] = + bigiA<:c> predT (fun i => + bigiA<:c> predT (fun j => p.[j] * q.[k - i - j] * r.[i]) 0 (k+1) + ) 0 (k+1). ++ rewrite polyMEr !big_seq &(eq_bigr) => /= i. + case/mem_range => ge0_i lt_i_Sk; rewrite (@polyMEw k) 1:/#. + by rewrite mulr_suml &(eq_bigr). +rewrite exchange_big &(eq_bigr) => /= i _. +by rewrite &(eq_bigr) => /= j _ /#. +qed. + +lemma polyM_mul1r : left_id poly1 polyM. +proof. +move=> p; apply: poly_eqP => i ge0_i. +rewrite polyME big_int_recl //= polyCE /= mul1r. +rewrite big_seq big1 -1:?addr0 //=. +move=> j; rewrite mem_range=> -[ge0_j _]; rewrite polyCE. +by rewrite addz1_neq0 //= mul0r. +qed. + +lemma polyM_mul0r p : poly0 * p = poly0. +proof. +apply/poly_eqP=> i _; rewrite poly0E polyME. +by rewrite big1 //= => j _; rewrite poly0E mul0r. +qed. + +lemma polyM_mulrDl : left_distributive polyM polyD. +proof. +move=> p q r; apply: poly_eqP => i ge0_i; rewrite !(polyME, polyDE). +by rewrite -big_split &(eq_bigr) => /= j _; rewrite polyDE mulrDl. +qed. + +lemma polyM_oner_neq0 : poly1 <> poly0. +proof. by apply/negP => /poly_eqP /(_ 0); rewrite !polyCE /= oner_neq0<:c>. qed. + +end section. + +(* -------------------------------------------------------------------- *) +(* Wrappers needed by [instance]: its [op X = name] clause requires a *) +(* qualified ident on the rhs (not an [abbrev]). *) +(* -------------------------------------------------------------------- *) +op poly_zero ['c <: comring] : 'c poly = polyC zero<:'c>. +op poly_one ['c <: comring] : 'c poly = polyC oner<:'c>. + +(* ==================================================================== *) +(* Phase 3: register [poly] as an [addgroup] over a [comring] *) +(* coefficient. Once this lands, every [bigA] / [bigZModule] lemma *) +(* polymorphic over [addmonoid] applies at carrier ['c poly]. *) +(* ==================================================================== *) +instance addgroup with ['c <: comring] ('c poly) + op idm = poly_zero<:'c> + op (+) = polyD<:'c> + op [-] = polyN<:'c> + + proof addmA by apply polyD_addrA + proof addmC by apply polyD_addrC + proof add0m by (move=> p; rewrite -/(poly_zero<:'c>); apply polyD_add0r) + proof addrN by (move=> p; rewrite polyD_addrC -/(poly_zero<:'c>); apply polyD_addNr). + +(* ==================================================================== *) +(* Phase 5: register [poly] as a [comring] over a [comring] coefficient.*) +(* Mirrors [Ring.ec:ComRingDflInv]: when no structural inverse is *) +(* available (here, because the structural "constant with invertible *) +(* coefficient" characterisation only holds when [c] has no zero *) +(* divisors, i.e. [c : idomain]), use [choiceb] to pick a left inverse *) +(* if any exists, fall back to the element itself otherwise. The three *) +(* obligations [mulVr] / [unitP] / [unitout] discharge from [choicebP] *) +(* and [choiceb_dfl] alone — no ring axioms needed. *) +(* ==================================================================== *) +op poly_unit ['c <: comring] (p : 'c poly) : bool = + exists q, polyM q p = poly_one<:'c>. + +op poly_invr ['c <: comring] (p : 'c poly) : 'c poly = + choiceb (fun q => polyM q p = poly_one<:'c>) p. + +instance comring with ['c <: comring] ('c poly) + op idm = poly_zero<:'c> + op (+) = polyD<:'c> + op [-] = polyN<:'c> + op oner = poly_one<:'c> + op ( * ) = polyM<:'c> + op invr = poly_invr<:'c> + op unit = poly_unit<:'c> + + proof addmA by apply polyD_addrA + proof addmC by apply polyD_addrC + proof add0m by (move=> p; rewrite -/(poly_zero<:'c>); apply polyD_add0r) + proof addrN by (move=> p; rewrite polyD_addrC -/(poly_zero<:'c>); apply polyD_addNr) + proof oner_neq0 by (rewrite -/(poly_one<:'c>) -/(poly_zero<:'c>); apply polyM_oner_neq0) + proof mulrA by apply polyM_mulrA + proof mulrC by apply polyM_mulrC + proof mul1r by (move=> p; rewrite -/(poly_one<:'c>); apply polyM_mul1r) + proof mulrDl by apply polyM_mulrDl + proof mulVr by (move=> p hu; rewrite /poly_invr<:'c>; + have := choicebP (fun q => polyM q p = poly_one<:'c>) p hu; + by rewrite /=) + proof unitP by (move=> p q heq; rewrite /poly_unit<:'c>; by exists q) + proof unitout by (move=> p; rewrite /poly_unit<:'c> /poly_invr<:'c> negb_exists => hne; + by apply choiceb_dfl => q; apply hne). + +(* ==================================================================== *) +(* Phase 6: higher-level theory of polynomials over a [comring] *) +(* coefficient. Mirrors [theories/algebra/Poly.ec] from [degC] *) +(* (line 296) onwards: degree arithmetic, multiplicative degree, *) +(* X^i / polyXn, polysumE / polyE / polywE, peval, polyL constructor. *) +(* ==================================================================== *) +section. +declare type c <: comring. + +(* -------------------------------------------------------------------- *) +(* Degree of constants, leading coefficient, [poly0]/[poly1] degrees. *) +(* -------------------------------------------------------------------- *) +lemma degC (a : c) : deg (polyC a) = if a = zero<:c> then 0 else 1. +proof. +case: (a = zero<:c>) => [->|nz_a]; last first. +- apply: degP => //=; first by rewrite polyCE. + by move=> i ge1_i; rewrite polyCE gtr_eqF //#. +rewrite /deg; apply: argmin_eq => //=. +- by move=> j _; rewrite poly0E. +- by move=> j; apply: contraL => _ /#. +qed. + +lemma degC_le (a : c) : deg (polyC a) <= 1. +proof. by rewrite degC; case: (a = zero<:c>). qed. + +lemma lcC (a : c) : lc (polyC a) = a. +proof. by rewrite polyCE degC; case: (a = zero<:c>) => [->|]. qed. + +lemma lc0 : lc poly0<:c> = zero<:c>. +proof. by apply: lcC. qed. + +lemma lc1 : lc poly1<:c> = oner<:c>. +proof. by apply: lcC. qed. + +lemma deg0 : deg poly0<:c> = 0. +proof. by rewrite degC. qed. + +lemma deg1 : deg poly1<:c> = 1. +proof. +apply: degP => //=; first by rewrite polyCE /= oner_neq0. +by move=> i ge1_i; rewrite polyCE gtr_eqF //#. +qed. + +lemma deg_eq0 (p : c poly) : (deg p = 0) <=> (p = poly0). +proof. +split=> [z_degp|->]; last by rewrite deg0. +apply/poly_eqP=> i ge0_i; rewrite poly0E. +by apply/gedeg_coeff; rewrite z_degp. +qed. + +lemma degX : deg X<:c> = 2. +proof. +apply/degP=> //=; first by rewrite polyXE /= oner_neq0. +by move=> i ge2_i; rewrite polyXE gtr_eqF //#. +qed. + +lemma nz_polyX : X<:c> <> poly0. +proof. by rewrite -deg_eq0 degX. qed. + +lemma lcX : lc X<:c> = oner<:c>. +proof. by rewrite degX /= polyXE. qed. + +lemma deg_ge1 (p : c poly) : (1 <= deg p) <=> (p <> poly0). +proof. by rewrite -deg_eq0 eqr_le ge0_deg /= (lerNgt _ 0) /#. qed. + +lemma deg_gt0 (p : c poly) : (0 < deg p) <=> (p <> poly0). +proof. by rewrite -deg_ge1 /#. qed. + +lemma deg_eq1 (p : c poly) : + (deg p = 1) <=> (exists a, a <> zero<:c> /\ p = polyC a). +proof. +split=> [eq1_degp|[a [nz_a ->>]]]; last first. ++ by apply: degP => //= => [|i ge1_i]; rewrite polyCE //= gtr_eqF /#. +have pC: forall i, 1 <= i => p.[i] = zero<:c>. ++ by move=> i ge1_i; apply: gedeg_coeff; rewrite eq1_degp. +exists p.[0]; split; last first. ++ apply/poly_eqP => i /ler_eqVlt -[<<-|]; first by rewrite polyCE. + by move=> gt0_i; rewrite polyCE gtr_eqF //= &(pC) /#. +apply: contraL eq1_degp => z_p0; suff ->: p = poly0 by rewrite deg0. +apply/poly_eqP=> i; rewrite poly0E => /ler_eqVlt [<<-//|]. +by move=> gt0_i; apply: pC => /#. +qed. + +lemma lc_eq0 (p : c poly) : (lc p = zero<:c>) <=> (p = poly0). +proof. +case: (p = poly0) => [->|] /=; first by rewrite lc0. +rewrite -deg_eq0 eqr_le ge0_deg /= -ltrNge => gt0_deg. +pose P i := forall j, (i <= j)%Int => p.[j] = zero<:c>. +apply/negP => zp; have h: 0 <= deg p - 1 < argmin idfun P. ++ rewrite /P /argmin -/(deg p); smt(ge0_deg). +have := argmin_min idfun P (deg p - 1) h. +move=> @/idfun /= j /ler_eqVlt [<<-//| ltj]. +by apply: gedeg_coeff => /#. +qed. + +(* -------------------------------------------------------------------- *) +(* Degree of additive operations. *) +(* -------------------------------------------------------------------- *) +lemma degN (p : c poly) : deg (-p) = deg p. +proof. +rewrite /deg; congr; apply/fun_ext => /= i; apply/eq_iff. +by split=> + j - /(_ j); rewrite polyNE oppr_eq0. +qed. + +lemma lcN (p : c poly) : lc (-p) = - lc p. +proof. by rewrite degN polyNE. qed. + +lemma degD (p q : c poly) : deg (p + q) <= max (deg p) (deg q). +proof. +apply: deg_leP; [by smt(ge0_deg) | move=> i /ler_maxrP[le1 le2]]. +by rewrite polyDE !gedeg_coeff ?addr0. +qed. + +lemma degB (p q : c poly) : deg (p - q) <= max (deg p) (deg q). +proof. by rewrite -(degN q) &(degD). qed. + +lemma degDl (p q : c poly) : deg q < deg p => deg (p + q) = deg p. +proof. +move=> le_pq; have gt0_p: 0 < deg p. +- by apply/(ler_lt_trans _ _ _ _ le_pq)/ge0_deg. +apply: degP=> //. +- rewrite polyDE (gedeg_coeff q) 1:/#. + by rewrite addr0 lc_eq0 -deg_eq0 gtr_eqF. +- move=> i le_pi; rewrite polyDE !gedeg_coeff ?addr0 //. + by apply/ltrW/(ltr_le_trans _ _ _ le_pq). +qed. + +lemma lcDl (p q : c poly) : deg q < deg p => lc (p + q) = lc p. +proof. +move=> ^lt_pq /degDl ->; rewrite polyDE. +by rewrite addrC gedeg_coeff ?add0r //#. +qed. + +lemma degDr (p q : c poly) : deg p < deg q => deg (p + q) = deg q. +proof. by move=> h; rewrite (addrC<:c poly> p q); apply degDl. qed. + +lemma lcDr (p q : c poly) : deg p < deg q => lc (p + q) = lc q. +proof. by move=> h; rewrite (addrC<:c poly> p q); apply lcDl. qed. + +(* -------------------------------------------------------------------- *) +(* Multiplicative degree. *) +(* -------------------------------------------------------------------- *) +lemma mul_lc (p q : c poly) : + lc p * lc q = (p * q).[deg p + deg q - 2]. +proof. +case: (p = poly0) => [->|nz_p]. +- by rewrite mul0r<:c poly> !poly0E mul0r. +case: (q = poly0) => [->|nz_q]. +- by rewrite polyM_mulrC polyM_mul0r !poly0E mulr0. +have ->: deg p + deg q - 2 = (deg p - 1) + (deg q - 1) by ring. +pose cp := deg p - 1; pose cq := deg q - 1. +rewrite polyME (bigD1 _ _ cp) ?range_uniq //=. +- rewrite mem_range subr_ge0 deg_ge1 nz_p /= -addrA. + by rewrite ltr_addl ltzS /cq subr_ge0 deg_ge1. +rewrite addrAC subrr /= big_seq_cond big1 ?addr0 //=. +move=> i [/mem_range [ge0_i lt] @/predC1 nei]. +case: (i < deg p) => [lt_ip| /lerNgt le_pi]; last first. +- by rewrite gedeg_coeff // mul0r. +by rewrite (gedeg_coeff q) ?mulr0 //#. +qed. + +(* -------------------------------------------------------------------- *) +lemma degM_le (p q : c poly) : p <> poly0 => q <> poly0 => + deg (p * q) + 1 <= deg p + deg q. +proof. +move=> nz_p nz_q; rewrite addrC -ler_subr_addl &(deg_leP). +- by move: nz_p nz_q; rewrite -!deg_eq0 !eqr_le !ge0_deg /= -!ltrNge /#. +move=> i lei; rewrite polyME big_seq big1 //=. +move=> j /mem_range [ge0_j /ltzS le_ij]. +case: (j < deg p) => [lt_jp|/lerNgt le_pk]. +- by rewrite mulrC gedeg_coeff ?mul0r //#. +- by rewrite gedeg_coeff ?mul0r //#. +qed. + +(* -------------------------------------------------------------------- *) +lemma degM_proper (p q : c poly) : + lc p * lc q <> zero<:c> => deg (p * q) = (deg p + deg q) - 1. +proof. +case: (p = poly0) => [->|nz_p]; first by rewrite lc0 !mul0r. +case: (q = poly0) => [->|nz_q]; first by rewrite lc0 !mulr0. +move=> nz_lc. +have ub := degM_le _ _ nz_p nz_q. +have lb : deg p + deg q - 1 <= deg (p * q). +- rewrite lerNgt /=; apply/negP => lt_pq. + apply nz_lc; rewrite mul_lc gedeg_coeff //#. +smt(). +qed. + +(* -------------------------------------------------------------------- *) +lemma lcM_proper (p q : c poly) : + lc p * lc q <> zero<:c> => lc (p * q) = lc p * lc q. +proof. by move=> reg; rewrite degM_proper //= -mul_lc. qed. + +(* -------------------------------------------------------------------- *) +lemma degZ_le (a : c) (p : c poly) : deg (a ** p) <= deg p. +proof. +case: (a = zero<:c>) => [->|nz_a]; 1: by rewrite scale0p deg0 ge0_deg. +case: (p = poly0) => [->|nz_p]; 1: by rewrite scalep0 deg0. +have nz_cp : polyC a <> poly0. +- by apply/negP => /(congr1 deg); rewrite deg0 degC nz_a. +rewrite scalepE -(ler_add2r 1); move/ler_trans: (degM_le _ _ nz_cp nz_p). +by apply; rewrite degC nz_a /= addrC. +qed. + +(* -------------------------------------------------------------------- *) +lemma degZ_lreg (a : c) (p : c poly) : lreg a => deg (a ** p) = deg p. +proof. +case: (p = poly0) => [->|^nz_p]; 1: by rewrite scalep0 deg0. +rewrite -deg_gt0 => gt0_dp lreg_a; apply/degP => // => [|i gei]. +- by rewrite polyZE mulrI_eq0 // lc_eq0. +- by rewrite gedeg_coeff // &(ler_trans (deg p)) // &(degZ_le). +qed. + +(* -------------------------------------------------------------------- *) +lemma lcZ_lreg (a : c) (p : c poly) : lreg a => lc (a ** p) = a * lc p. +proof. by move=> reg_a; rewrite degZ_lreg // polyZE. qed. + +(* -------------------------------------------------------------------- *) +(* polyXn / [exp X i] theory. *) +(* -------------------------------------------------------------------- *) +lemma polyCX (a : c) i : 0 <= i => exp (polyC a) i = polyC (exp a i). +proof. +elim: i => [|i ge0_i ih]; first by rewrite !expr0. +by rewrite !exprS // ih polyCM. +qed. + +(* -------------------------------------------------------------------- *) +lemma degXn_le (p : c poly) i : + p <> poly0 => 0 <= i => deg (exp p i) <= i * (deg p - 1) + 1. +proof. +move=> nz_p; elim: i => [|i ge0_i ih]; first by rewrite !expr0 deg1. +rewrite exprS // mulrDl /= addrAC !addrA ler_subr_addl (addrC<:int> 1). +case: (exp p i = poly0) => [->|nz_pX]. +- by rewrite mulr0 deg0 /=; rewrite -deg_gt0 in nz_p => /#. +apply: (ler_trans (deg p + deg (exp p i))); 1: by apply: degM_le. +by rewrite addrC &(ler_add2r). +qed. + +(* -------------------------------------------------------------------- *) +lemma lreg_lc (p : c poly) : lreg (lc p) => lreg p. +proof. +move/mulrI_eq0=> reg_p; apply/mulrI0_lreg => q. +apply: contraLR=> nz_q; rewrite -lc_eq0. +by rewrite lcM_proper reg_p lc_eq0. +qed. + +(* -------------------------------------------------------------------- *) +lemma degXn_proper (p : c poly) i : + lreg (lc p) => 0 <= i => deg (exp p i) = i * (deg p - 1) + 1. +proof. +move=> lreg_p; elim: i => [|i ge0_i ih]; first by rewrite expr0 deg1. +rewrite exprS // degM_proper; last by rewrite ih #ring. +by rewrite mulrI_eq0 // lc_eq0 lreg_neq0 // &(lregXn) // &(lreg_lc). +qed. + +(* -------------------------------------------------------------------- *) +lemma lcXn_proper (p : c poly) i : + lreg (lc p) => 0 <= i => lc (exp p i) = exp (lc p) i. +proof. +move=> reg_p; elim: i => [|i ge0_i ih]; 1: by rewrite !expr0 lc1. +rewrite !exprS // degM_proper /=; last by rewrite -mul_lc ih. +by rewrite mulrI_eq0 // lreg_neq0 // ih lregXn. +qed. + +(* -------------------------------------------------------------------- *) +lemma deg_polyXn i : 0 <= i => deg (exp X<:c> i) = i + 1. +proof. +move=> ge0_i; rewrite degXn_proper //. +- by rewrite lcX &(lreg1). +- by rewrite degX #ring. +qed. + +(* -------------------------------------------------------------------- *) +lemma lc_polyXn i : 0 <= i => lc (exp X<:c> i) = oner<:c>. +proof. +move=> ge0_i; rewrite lcXn_proper ?lcX //. +- by apply: lreg1. +- by rewrite expr1z. +qed. + +(* -------------------------------------------------------------------- *) +lemma deg_polyXnDC i (a : c) : 0 < i => deg (exp X<:c> i + polyC a) = i + 1. +proof. by move=> ge0_i; rewrite degDl 1?degC deg_polyXn 1:ltrW //#. qed. + +(* -------------------------------------------------------------------- *) +lemma lc_polyXnDC i (a : c) : 0 < i => lc (exp X<:c> i + polyC a) = oner<:c>. +proof. +move=> gti_0; rewrite lcDl ?lc_polyXn // -1:ltrW //. +- by rewrite degC deg_polyXn 1:ltrW //#. +qed. + +(* -------------------------------------------------------------------- *) +lemma polyXnE i k : + 0 <= i => (exp X<:c> i).[k] = if k = i then oner<:c> else zero<:c>. +proof. +move=> ge0_i; elim: i ge0_i k => [|i ge0_i ih] k. +- by rewrite expr0 polyCE. +- by rewrite exprS // (mulrC<:c poly>) polyMXE ih /#. +qed. + +(* -------------------------------------------------------------------- *) +(* Sums of polys. *) +(* -------------------------------------------------------------------- *) +lemma polysumE ['a] (P : 'a -> bool) (F : 'a -> c poly) (s : 'a list) k : + (bigA P F s).[k] = bigA P (fun i => (F i).[k]) s. +proof. +elim: s => /= [|x s ih]; first by rewrite !big_nil poly0E. +rewrite !big_cons -ih /=. +by rewrite -polyDE -(fun_if (fun q : c poly => q.[k])). +qed. + +(* -------------------------------------------------------------------- *) +lemma polyE (p : c poly) : + p = bigiA predT (fun i => p.[i] ** exp X<:c> i) 0 (deg p). +proof. +apply/poly_eqP=> i ge0_i; rewrite polysumE /=; case: (i < deg p). +- move=> lt_i_dp; rewrite (bigD1 _ _ i) ?(mem_range, range_uniq) //=. + rewrite !(coeffpE, polyXnE) //= mulr1 big1_seq ?addr0 //=. + move=> @/predC1 j [ne_ji /mem_range [ge0_j _]]. + by rewrite !(coeffpE, polyXnE) // (eq_sym i j) ne_ji /= mulr0. +- move=> /lerNgt ge_i_dp; rewrite gedeg_coeff //. + rewrite big_seq big1 //= => j /mem_range [ge0_j lt_j]. + by rewrite !(coeffpE, polyXnE) // (_ : i <> j) ?mulr0 //#. +qed. + +(* -------------------------------------------------------------------- *) +lemma polywE n (p : c poly) : deg p <= n => + p = bigiA predT (fun i => p.[i] ** exp X<:c> i) 0 n. +proof. +move=> le_pn; rewrite (big_cat_int (deg p)) // ?ge0_deg. +rewrite {1}polyE; pose r := bigA _ _ _. +pose d := bigA _ _ _; suff ->: d = poly0. +- by apply/poly_eqP=> i ge0_i; rewrite polyDE poly0E addr0. +rewrite /d big_seq big1 => //= i /mem_range [gei _]. +by rewrite gedeg_coeff // scale0p. +qed. + +(* -------------------------------------------------------------------- *) +lemma deg_sum ['a] (P : 'a -> bool) (F : 'a -> c poly) (r : 'a list) k : + 0 <= k + => (forall x, P x => deg (F x) <= k) + => deg (bigA P F r) <= k. +proof. +move=> ge0_k le; elim: r => [|x r ih]; 1: by rewrite big_nil deg0. +rewrite big_cons; case: (P x) => // Px. +by rewrite &(ler_trans _ _ _ (degD _ _)) ler_maxrP ih le. +qed. + +(* -------------------------------------------------------------------- *) +(* Polynomial evaluation. *) +(* -------------------------------------------------------------------- *) +op peval (p : c poly) (a : c) = + bigiA<:c> predT (fun i => p.[i] * exp a i) 0 (deg p + 1). + +abbrev root (p : c poly) (a : c) = peval p a = zero<:c>. + +(* -------------------------------------------------------------------- *) +(* polyL: build a polynomial from a coefficient list. *) +(* -------------------------------------------------------------------- *) +op prepolyL (a : c list) : int -> c = fun i => nth zero<:c> a i. + +lemma isprepolyL a : ispoly (prepolyL a). +proof. +split=> [i lt0_i|]; first by rewrite /prepolyL nth_neg. +exists (size a) => i gti; rewrite /prepolyL nth_out //. +by apply/negP => -[_]; rewrite ltrNge /= ltrW. +qed. + +op polyL (a : c list) : c poly = to_polyd (prepolyL a). + +lemma polyLE a i : (polyL a).[i] = nth zero<:c> a i. +proof. by rewrite coeffE 1:isprepolyL. qed. + +lemma degL_le a : deg (polyL a) <= size a. +proof. +apply: deg_leP; first exact: size_ge0. +by move=> i gei; rewrite polyLE nth_out //#. +qed. + +lemma degL a : + last zero<:c> a <> zero<:c> => deg (polyL a) = size a. +proof. +move=> nz; apply/degP. +- by case: a nz => //= x s _; rewrite addrC ltzS size_ge0. +- by rewrite polyLE nth_last. +- move=> i sza; rewrite gedeg_coeff //. + by apply: (ler_trans (size a)) => //; apply: degL_le. +qed. + +lemma inj_polyL a1 a2 : + size a1 = size a2 => polyL a1 = polyL a2 => a1 = a2. +proof. +move=> eq_sz /poly_eqP eq; apply: (eq_from_nth zero<:c>)=> //. +by move=> i [+ _] - /eq; rewrite !polyLE. +qed. + +lemma surj_polyL p n : + deg p <= n => exists s, size s = n /\ p = polyL s. +proof. +move=> len; exists (map (fun i => p.[i]) (range 0 n)); split. +- by rewrite size_map size_range /=; smt(ge0_deg). +apply/poly_eqP=> i ge0_i; rewrite polyLE; case: (i < n). +- by move=> lt_in; rewrite (nth_map 0) ?size_range ?nth_range //#. +- rewrite ltrNge /= => le_ni; rewrite gedeg_coeff // 1:/#. + by rewrite nth_out // size_map size_range /#. +qed. + +end section. + +(* ==================================================================== *) +(* Phase 7: idomain extension. Mirrors [theories/algebra/Poly.ec:Poly] *) +(* (the idomain-coefficient phase). Adds the multiplicativity of [deg] *) +(* and [lc], the no-zero-divisor property, and the structural *) +(* characterisation lemmas [unitE]/[polyVE] bridging the choiceb-based *) +(* [poly_unit]/[poly_invr] (committed at Phase 5) to the structural *) +(* "deg=1 with invertible constant" form available when [c : idomain]. *) +(* ==================================================================== *) +section. +declare type c <: idomain. + +(* -------------------------------------------------------------------- *) +lemma degM (p q : c poly) : p <> poly0 => q <> poly0 => + deg (p * q) = deg p + deg q - 1. +proof. +rewrite -!lc_eq0 -!lregP => reg_p reg_q. +by rewrite &(degM_proper) mulf_eq0 negb_or -!lregP. +qed. + +(* -------------------------------------------------------------------- *) +lemma lcM (p q : c poly) : lc (p * q) = lc p * lc q. +proof. +case: (p = poly0) => [->|nz_p]; first by rewrite polyM_mul0r !lc0 mul0r. +case: (q = poly0) => [->|nz_q]. +- by rewrite polyM_mulrC polyM_mul0r !lc0 mulr0. +by rewrite lcM_proper // mulf_eq0 !lc_eq0 !(nz_p, nz_q). +qed. + +(* -------------------------------------------------------------------- *) +(* No zero divisors at the poly level (the [mulf_eq0] axiom one would *) +(* need to register [idomain with ('c poly)]). *) +(* -------------------------------------------------------------------- *) +lemma polyM_mulf_eq0 (p q : c poly) : + p * q = poly0 <=> p = poly0 \/ q = poly0. +proof. +split; last by case=> ->; rewrite ?polyM_mul0r // polyM_mulrC polyM_mul0r. +apply: contraLR; rewrite negb_or => -[nz_p nz_q]; apply/negP. +move/(congr1 (fun r : c poly => deg r + 1)) => /=; rewrite deg0 degM //=. +by rewrite gtr_eqF // -lez_add1r ler_add deg_ge1. +qed. + +(* -------------------------------------------------------------------- *) +(* Structural characterisation of [poly_unit] / [poly_invr] when *) +(* [c : idomain]. Bridges the choiceb-based forms committed at Phase 5 *) +(* to the deg=1-with-invertible-constant form usable in proofs. The *) +(* underlying ops (poly_unit, poly_invr) remain as registered; *) +(* downstream code rewrites with these equivalences. *) +(* -------------------------------------------------------------------- *) +lemma unitE (p : c poly) : + poly_unit p <=> deg p = 1 /\ unit p.[0]. +proof. +rewrite /poly_unit; split. +- case=> q pMqE. + have nz_p : p <> poly0. + - apply/negP=> ->>; have := pMqE; rewrite polyM_mulrC polyM_mul0r => /eq_sym. + by move/(congr1 (fun r : c poly => r.[0])) => /=; + rewrite poly0E polyCE /=; smt(oner_neq0). + have nz_q : q <> poly0. + - apply/negP=> ->>; have := pMqE; rewrite polyM_mul0r => /eq_sym. + by move/(congr1 (fun r : c poly => r.[0])) => /=; + rewrite poly0E polyCE /=; smt(oner_neq0). + have /(congr1 deg) : polyM q p = poly1 by exact pMqE. + rewrite deg1 degM //= => sum_eq. + have ge1_p : 1 <= deg p by rewrite deg_ge1. + have ge1_q : 1 <= deg q by rewrite deg_ge1. + have [dq_eq dp_eq] : deg q = 1 /\ deg p = 1 by smt(). + split=> //. + move/poly_eqP: pMqE => /(_ 0 _) //; rewrite polyCE /=. + by rewrite polyME big_int1 /= => /unitP. +- case=> dp_eq1 unit_p0; case/deg_eq1: dp_eq1 => a [nz_a ->>]. + exists (polyC (invr a)); apply/poly_eqP=> i ge0_i. + rewrite polyCE polyME; case: (i = 0) => [->>|ne0_i] /=. + - rewrite big_int1 /= !polyCE /= mulVr //. + by move: unit_p0; rewrite polyCE. + rewrite big_seq big1 ?addr0 //= => j /mem_range [ge0_j _]. + rewrite !polyCE; case: (j = 0) => [->>/=|/= _]. + - by rewrite ne0_i /= mulr0. + - by rewrite mul0r. +qed. + +(* -------------------------------------------------------------------- *) +(* Structural value of [poly_invr] for unit polynomials over an + idomain coefficient: [poly_invr (polyC a) = polyC (invr a)] when + [unit a]. The choiceb's witness [q : q * polyC a = poly1] is + uniquely [polyC (invr a)] modulo invertibility, which suffices for + pointwise equality. *) +(* -------------------------------------------------------------------- *) +lemma polyVE (a : c) : unit a => poly_invr (polyC a) = polyC (invr a). +proof. +move=> ua; rewrite /poly_invr. +have ex_q : exists q, polyM q (polyC a) = poly_one<:c>. +- exists (polyC (invr a)); apply/poly_eqP=> i ge0_i. + rewrite polyME /poly_one polyCE; case: (i = 0) => [->>|nei] /=. + - by rewrite big_int1 /= !polyCE /= mulVr. + rewrite big_seq big1 ?addr0 //= => j /mem_range [ge0_j _]. + rewrite !polyCE; case: (j = 0) => [->>/=|/= _]. + - by rewrite nei /= mulr0. + - by rewrite mul0r. +have := choicebP (fun q => polyM q (polyC a) = poly_one<:c>) (polyC a) ex_q. +move=> /= choice_eq. +(* Both [choiceb …] and [polyC (invr a)] are left inverses of [polyC a]; + uniqueness via no-zero-divisors yields equality. *) +pose q := choiceb (fun q => polyM q (polyC a) = poly_one<:c>) (polyC a). +have qE : polyM q (polyC a) = poly_one<:c> by exact choice_eq. +apply/poly_eqP=> i ge0_i. +have polyC_invr_eq : polyM (polyC (invr a)) (polyC a) = poly_one<:c>. +- apply/poly_eqP=> j ge0_j; rewrite polyME /poly_one polyCE. + case: (j = 0) => [->>|nej] /=. + - by rewrite big_int1 /= !polyCE /= mulVr. + rewrite big_seq big1 ?addr0 //= => k /mem_range [ge0_k _]. + rewrite !polyCE; case: (k = 0) => [->>/=|/= _]. + - by rewrite nej /= mulr0. + - by rewrite mul0r. +have eq2 : polyM q (polyC a) = polyM (polyC (invr a)) (polyC a) + by rewrite qE -polyC_invr_eq. +(* Cancel [polyC a] on the right: it has [unit] coeff, so it's [lreg]. *) +have nz_a : a <> zero<:c>. +- apply/negP=> a0; have h := mulVr a ua; rewrite a0 mulr0 in h. + by move: h => /eq_sym; smt(oner_neq0). +have lreg_pCa : lreg (polyC a). +- apply lreg_lc; rewrite lcC; apply/lregP/nz_a. +have inj_pCa : injective (fun y : c poly => polyM y (polyC a)). +- by move=> x y; rewrite (polyM_mulrC x) (polyM_mulrC y) => /lreg_pCa. +have q_eq : q = polyC (invr a) by apply: inj_pCa. +by rewrite q_eq. +qed. + +end section. diff --git a/examples/tcalgebra/TcPolySmokeTest.ec b/examples/tcalgebra/TcPolySmokeTest.ec new file mode 100644 index 0000000000..89b0bc7af6 --- /dev/null +++ b/examples/tcalgebra/TcPolySmokeTest.ec @@ -0,0 +1,99 @@ +(* ==================================================================== *) +(* Smoke test for TcPoly: instantiate the parametric polynomial *) +(* library at carrier [int] (which is registered as [idomain] via *) +(* TcInt) and exercise representative lemmas from each phase. Confirms *) +(* the registered instances flow end-to-end through TC reduction. *) +(* ==================================================================== *) +require import AllCore List. +require import TcMonoid TcRing TcBigop TcBigalg TcInt. +require import TcPoly. + +(* -------------------------------------------------------------------- *) +(* Phase 1-2: constructors / coefficient formulas. *) +lemma test_polyCE (a : int) (k : int) : + (polyC<:int> a).[k] = if k = 0 then a else 0. +proof. by rewrite polyCE. qed. + +lemma test_polyXE (k : int) : + (X<:int>).[k] = if k = 1 then 1 else 0. +proof. by rewrite polyXE. qed. + +(* -------------------------------------------------------------------- *) +(* Phase 4: multiplication on int polys. *) +lemma test_mulrA (p q r : int poly) : + polyM p (polyM q r) = polyM (polyM p q) r. +proof. by apply polyM_mulrA. qed. + +lemma test_mulrC (p q : int poly) : polyM p q = polyM q p. +proof. by apply polyM_mulrC. qed. + +(* -------------------------------------------------------------------- *) +(* Phase 6a: degree arithmetic on int polys. *) +lemma test_degC (a : int) : + deg (polyC<:int> a) = if a = 0 then 0 else 1. +proof. by rewrite degC. qed. + +lemma test_deg0 : deg poly0<:int> = 0. +proof. by rewrite deg0. qed. + +lemma test_deg1 : deg poly1<:int> = 1. +proof. by rewrite deg1. qed. + +lemma test_degX : deg X<:int> = 2. +proof. by rewrite degX. qed. + +(* -------------------------------------------------------------------- *) +(* Phase 6c: polyXn / X^i theory. *) +lemma test_deg_polyXn (i : int) : 0 <= i => deg (exp X<:int> i) = i + 1. +proof. by apply deg_polyXn. qed. + +lemma test_lc_polyXn (i : int) : 0 <= i => lc (exp X<:int> i) = 1. +proof. by apply lc_polyXn. qed. + +(* -------------------------------------------------------------------- *) +(* Phase 7: idomain-only lemmas — multiplicativity of [deg] / [lc]. *) +lemma test_degM (p q : int poly) : + p <> poly0 => q <> poly0 => deg (polyM p q) = deg p + deg q - 1. +proof. by apply degM. qed. + +lemma test_lcM (p q : int poly) : lc (polyM p q) = lc p * lc q. +proof. by apply lcM. qed. + +lemma test_polyM_mulf_eq0 (p q : int poly) : + polyM p q = poly0 <=> p = poly0 \/ q = poly0. +proof. by apply polyM_mulf_eq0. qed. + +(* -------------------------------------------------------------------- *) +(* Concrete computation through the convolution: coefficient at index 0 + of [(X + polyC 1) * (X + polyC (-1))] equals -1. Spot-check that + [polyM] reduces correctly through the registered comring chain. *) +lemma test_polyM_at_0 : + (polyM<:int> (polyD X (polyC 1)) (polyD X (polyC (-1)))).[0] = -1. +proof. +rewrite polyME big_int1 /=. +by rewrite !(polyDE, polyXE, polyCE) /= !(mul0r, mulr0, addr0, mul1r, add0r). +qed. + +(* -------------------------------------------------------------------- *) +(* polyL constructor on int. *) +lemma test_polyLE (xs : int list) (k : int) : + (polyL xs).[k] = nth 0 xs k. +proof. by rewrite polyLE. qed. + +(* -------------------------------------------------------------------- *) +(* Class lemmas at carrier [int poly] — exercises the parametric Path B *) +(* path through the unifier's flush + matcher's drain. *) +(* -------------------------------------------------------------------- *) +lemma test_addrC_at_int_poly (p q : int poly) : p + q = q + p. +proof. by apply (addrC<:int poly>). qed. + +lemma test_addrA_at_int_poly (p q r : int poly) : + p + (q + r) = (p + q) + r. +proof. by apply (addrA<:int poly>). qed. + +lemma test_mulrC_at_int_poly (p q : int poly) : p * q = q * p. +proof. by apply (mulrC<:int poly>). qed. + +lemma test_mulrA_at_int_poly (p q r : int poly) : + p * (q * r) = (p * q) * r. +proof. by apply (mulrA<:int poly>). qed. diff --git a/examples/tcalgebra/TcRing.ec b/examples/tcalgebra/TcRing.ec new file mode 100644 index 0000000000..e7ac5db51f --- /dev/null +++ b/examples/tcalgebra/TcRing.ec @@ -0,0 +1,876 @@ +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import Core Int. +require import TcMonoid. + +(* ==================================================================== *) +(* Additive group: extends [addmonoid] with negation. Carrier of all + ZModule lemmas in the original [theories/algebra/Ring.ec]. *) +(* ==================================================================== *) +type class addgroup <: addmonoid = { + op [-] : addgroup -> addgroup + + axiom addrN : right_inverse zero<:addgroup> [-] (+)<:addgroup> +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: addgroup. + +(* Re-export the inherited addmonoid axioms under the conventional + ring-theoretic names. *) +lemma addrA: associative (+)<:t>. +proof. exact addmA. qed. + +lemma addrC: commutative (+)<:t>. +proof. exact addmC. qed. + +lemma add0r: left_id zero<:t> (+)<:t>. +proof. exact add0m. qed. + +(* The original [Ring.ec] takes [addNr] as the additive group axiom and + derives [addrN] from it; here we take [addrN] (right inverse) and + derive [addNr] (left inverse) instead. *) +lemma addNr: left_inverse zero<:t> [-] (+)<:t>. +proof. by move=> x; rewrite addrC addrN. qed. + +abbrev (-) (x y : t) = x + -y. + +lemma addr0: right_id zero<:t> (+). +proof. exact addm0. qed. + +lemma addrCA: left_commutative (+)<:t>. +proof. exact addmCA. qed. + +lemma addrAC: right_commutative (+)<:t>. +proof. exact addmAC. qed. + +lemma addrACA: interchange (+)<:t> (+). +proof. exact addmACA. qed. + +lemma subrr (x : t): x - x = zero. +proof. by rewrite addrN. qed. + +hint simplify subrr. + +lemma addKr: left_loop ([-]<:t>) (+). +proof. by move=> x y; rewrite addrA addNr add0r. qed. + +lemma addNKr: rev_left_loop ([-]<:t>) (+). +proof. by move=> x y; rewrite addrA addrN add0r. qed. + +lemma addrK: right_loop ([-]<:t>) (+). +proof. by move=> x y; rewrite -addrA addrN addr0. qed. + +lemma addrNK: rev_right_loop ([-]<:t>) (+). +proof. by move=> x y; rewrite -addrA addNr addr0. qed. + +lemma subrK (x y : t): (x - y) + y = x. +proof. by rewrite addrNK. qed. + +lemma addrI: right_injective (+)<:t>. +proof. by move=> x y z h; rewrite -(@addKr x z) -h addKr. qed. + +lemma addIr: left_injective (+)<:t>. +proof. by move=> x y z h; rewrite -(@addrK x z) -h addrK. qed. + +lemma opprK: involutive ([-]<:t>). +proof. by move=> x; apply (@addIr (-x)); rewrite addNr addrN. qed. + +lemma oppr_inj : injective ([-]<:t>). +proof. by move=> x y eq; apply/(addIr (-x)); rewrite subrr eq subrr. qed. + +lemma oppr0 : -zero<:t> = zero. +proof. by rewrite -(@addr0 (-zero)) addNr. qed. + +lemma oppr_eq0 (x : t) : (- x = zero) <=> (x = zero). +proof. by rewrite (inv_eq opprK) oppr0. qed. + +lemma subr0 (x : t): x - zero = x. +proof. by rewrite oppr0 addr0. qed. + +lemma sub0r (x : t): zero - x = - x. +proof. by rewrite add0r. qed. + +lemma opprD (x y : t): -(x + y) = -x + -y. +proof. by apply (@addrI (x + y)); rewrite addrA addrN addrAC addrK addrN. qed. + +lemma opprB (x y : t): -(x - y) = y - x. +proof. by rewrite opprD opprK addrC. qed. + +lemma subrACA: interchange (fun (x y : t) => x - y) (+). +proof. by move=> x y z u; rewrite addrACA opprD. qed. + +lemma subr_eq (x y z : t): + (x - z = y) <=> (x = y + z). +proof. +move: (can2_eq (fun x, x - z) (fun x, x + z) _ _ x y) => //=. ++ by move=> {x} x /=; rewrite addrNK. ++ by move=> {x} x /=; rewrite addrK. +qed. + +lemma subr_eq0 (x y : t): (x - y = zero) <=> (x = y). +proof. by rewrite subr_eq add0r. qed. + +lemma addr_eq0 (x y : t): (x + y = zero) <=> (x = -y). +proof. by rewrite -(@subr_eq0 x) opprK. qed. + +lemma eqr_opp (x y : t): (- x = - y) <=> (x = y). +proof. by apply/(@can_eq _ _ opprK x y). qed. + +lemma eqr_oppLR (x y : t) : (- x = y) <=> (x = - y). +proof. by apply/(@inv_eq _ opprK x y). qed. + +lemma eqr_sub (x y z u : t) : (x - y = z - u) <=> (x + u = z + y). +proof. +rewrite -{1}(addrK u x) -{1}(addrK y z) -!addrA. +by rewrite (addrC (-u)) !addrA; split=> [/addIr /addIr|->//]. +qed. + +lemma subr_add2r (z x y : t): (x + z) - (y + z) = x - y. +proof. by rewrite opprD addrACA addrN addr0. qed. +end section. + +(* -------------------------------------------------------------------- *) +(* [intmul x n] is [n] copies of [x] folded with [+]; for negative [n] + it is [-(intmul x (-n))]. Foundational for [ofint] and for + characterizing ring exponents. *) +op intmul ['a <: addgroup] (x : 'a) (n : int) = + if n < 0 + then -(iterop (-n) (+) x zero) + else (iterop n (+) x zero). + +(* -------------------------------------------------------------------- *) +section. +declare type t <: addgroup. + +lemma intmulpE (x : t) (c : int) : 0 <= c => + intmul x c = iterop c (+) x zero. +proof. by rewrite /intmul lezNgt => ->. qed. + +lemma mulr0z (x : t): intmul x 0 = zero. +proof. by rewrite /intmul /= iterop0. qed. + +lemma mulr1z (x : t): intmul x 1 = x. +proof. by rewrite /intmul /= iterop1. qed. + +lemma mulr2z (x : t): intmul x 2 = x + x. +proof. by rewrite /intmul /= (@iteropS 1) // (@iterS 0) // iter0. qed. + +lemma mulrNz (x : t) (n : int): intmul x (-n) = -(intmul x n). +proof. +case: (n = 0)=> [->|nz_c]; first by rewrite oppz0 mulr0z oppr0. +rewrite /intmul oppz_lt0 oppzK ltz_def nz_c lezNgt /=. +by case: (n < 0); rewrite ?opprK. +qed. + +lemma mulrS (x : t) (n : int): 0 <= n => + intmul x (n+1) = x + intmul x n. +proof. +move=> ge0n; rewrite !intmulpE 1:addz_ge0 //. +by rewrite !iteropE iterS. +qed. + +lemma mulNrz (x : t) (n : int) : intmul (-x) n = - (intmul x n). +proof. +elim/intwlog: n => [n h| | n ge0_n ih]. ++ by rewrite -(@oppzK n) !(@mulrNz _ (- n)) h. ++ by rewrite !mulr0z oppr0. ++ by rewrite !mulrS // ih opprD. +qed. + +lemma mulNrNz (x : t) (n : int) : intmul (-x) (-n) = intmul x n. +proof. by rewrite mulNrz mulrNz opprK. qed. + +lemma mulrSz (x : t) (n : int) : intmul x (n + 1) = x + intmul x n. +proof. +case: (0 <= n) => [/mulrS ->//|]; rewrite -ltzNge => gt0_n. +case: (n = -1) => [->/=|]; 1: by rewrite mulrNz mulr1z mulr0z subrr. +move=> neq_n_N1; rewrite -!(@mulNrNz x). +rewrite (_ : -n = -(n+1) + 1) 1:/# mulrS 1:/#. +by rewrite addrA subrr add0r. +qed. + +lemma mulrDz (x : t) (n m : int) : intmul x (n + m) = intmul x n + intmul x m. +proof. +wlog: n m / 0 <= m => [wlog|]. ++ case: (0 <= m) => [/wlog|]; first by apply. + rewrite -ltzNge => lt0_m; rewrite (_ : n + m = -(-m - n)) 1:/#. + by rewrite mulrNz addzC wlog 1:/# !mulrNz -opprD opprK. +elim: m => /= [|m ge0_m ih]; first by rewrite mulr0z addr0. +by rewrite addzA !mulrSz ih addrCA. +qed. +end section. + +(* ==================================================================== *) +(* Commutative ring: addgroup + multiplicative commutative monoid + + distributivity. Multi-parent factory inheritance: comring inherits + from [addgroup] and from [mulmonoid] (with [idm := oner] and + [(+) := ( * )]). The locally-declared [oner] / [( * )] are aliases + for the inherited mulmonoid ops; the multiplicative + associativity / commutativity / left-id axioms ([mulrA] / [mulrC] + / [mul1r]) are kept as axioms in the class body so they're + available under conventional ring-theoretic names downstream. *) +(* ==================================================================== *) +type class comring <: addgroup & (mulmonoid with idm = oner, (+) = ( * )) = { + op oner : comring + op ( * ) : comring -> comring -> comring + op invr : comring -> comring + op unit : comring -> bool + + axiom oner_neq0 : oner <> zero<:comring> + axiom mulrA : associative ( * ) + axiom mulrC : commutative ( * ) + axiom mul1r : left_id oner ( * ) + axiom mulrDl : left_distributive ( * ) (+)<:comring> + axiom mulVr : left_inverse_in unit oner invr ( * ) + axiom unitP : forall (x y : comring), y * x = oner => unit x + axiom unitout : forall (x : comring), !unit x => invr x = x +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: comring. + +abbrev (/) (x y : t) = x * (invr y). + +lemma mulr1: right_id oner<:t> ( * ). +proof. by move=> x; rewrite mulrC mul1r. qed. + +lemma mulrCA: left_commutative ( * )<:t>. +proof. by move=> x y z; rewrite !mulrA (@mulrC x y). qed. + +lemma mulrAC: right_commutative ( * )<:t>. +proof. by move=> x y z; rewrite -!mulrA (@mulrC y z). qed. + +lemma mulrACA: interchange ( * )<:t> ( * ). +proof. by move=> x y z u; rewrite -!mulrA (mulrCA y). qed. + +lemma mulrSl (x y : t) : (x + oner) * y = x * y + y. +proof. by rewrite mulrDl mul1r. qed. + +lemma mulrDr: right_distributive ( * )<:t> (+). +proof. by move=> x y z; rewrite mulrC mulrDl !(@mulrC _ x). qed. + +lemma mul0r: left_zero zero<:t> ( * ). +proof. by move=> x; apply: (@addIr (oner * x)); rewrite -mulrDl !add0r mul1r. qed. + +lemma mulr0: right_zero zero<:t> ( * ). +proof. by move=> x; apply: (@addIr (x * oner)); rewrite -mulrDr !add0r mulr1. qed. + +lemma mulrN (x y : t): x * (- y) = - (x * y). +proof. by apply: (@addrI (x * y)); rewrite -mulrDr !addrN mulr0. qed. + +lemma mulNr (x y : t): (- x) * y = - (x * y). +proof. by apply: (@addrI (x * y)); rewrite -mulrDl !addrN mul0r. qed. + +lemma mulrNN (x y : t): (- x) * (- y) = x * y. +proof. by rewrite mulrN mulNr opprK. qed. + +lemma mulN1r (x : t): (-oner) * x = -x. +proof. by rewrite mulNr mul1r. qed. + +lemma mulrN1 (x : t): x * -oner = -x. +proof. by rewrite mulrN mulr1. qed. + +lemma mulrBl: left_distributive ( * )<:t> (fun (x y : t) => x - y). +proof. by move=> x y z; rewrite mulrDl !mulNr. qed. + +lemma mulrBr: right_distributive ( * )<:t> (fun (x y : t) => x - y). +proof. by move=> x y z; rewrite mulrDr !mulrN. qed. + +(* -------------------------------------------------------------------- *) +(* Multiplicative-inverse / unit theory. *) +(* -------------------------------------------------------------------- *) + +lemma mulrV: right_inverse_in unit<:t> oner invr ( * ). +proof. by move=> x /mulVr; rewrite mulrC. qed. + +lemma divrr (x : t): unit x => x / x = oner. +proof. by apply/mulrV. qed. + +lemma invr_out (x : t): !unit x => invr x = x. +proof. by apply/unitout. qed. + +lemma unitrP (x : t): unit x <=> (exists y, y * x = oner). +proof. by split=> [/mulVr<- |]; [exists (invr x) | case=> y /unitP]. qed. + +lemma mulKr: left_loop_in unit<:t> invr ( * ). +proof. by move=> x un_x y; rewrite mulrA mulVr // mul1r. qed. + +lemma mulrK: right_loop_in unit<:t> invr ( * ). +proof. by move=> y un_y x; rewrite -mulrA mulrV // mulr1. qed. + +lemma mulVKr: rev_left_loop_in unit<:t> invr ( * ). +proof. by move=> x un_x y; rewrite mulrA mulrV // mul1r. qed. + +lemma mulrVK: rev_right_loop_in unit<:t> invr ( * ). +proof. by move=> y nz_y x; rewrite -mulrA mulVr // mulr1. qed. + +lemma mulrI: right_injective_in unit<:t> ( * ). +proof. by move=> x Ux; have /can_inj h := mulKr _ Ux. qed. + +lemma mulIr: left_injective_in unit<:t> ( * ). +proof. by move=> x /mulrI h y1 y2; rewrite !(@mulrC _ x) => /h. qed. + +lemma unitrE (x : t): unit x <=> (x / x = oner). +proof. +split=> [Ux|xx1]; 1: by apply/divrr. +by apply/unitrP; exists (invr x); rewrite mulrC. +qed. + +lemma invrK: involutive invr<:t>. +proof. +move=> x; case: (unit x)=> Ux; 2: by rewrite !invr_out. +rewrite -(mulrK _ Ux (invr (invr x))) -mulrA. +rewrite (@mulrC x) mulKr //; apply/unitrP. +by exists x; rewrite mulrV. +qed. + +lemma invr_inj: injective invr<:t>. +proof. by apply: (can_inj _ _ invrK). qed. + +lemma unitrV (x : t): unit (invr x) <=> unit x. +proof. by rewrite !unitrE invrK mulrC. qed. + +lemma unitr1: unit<:t> oner. +proof. by apply/unitrP; exists oner; rewrite mulr1. qed. + +lemma invr1: invr oner<:t> = oner. +proof. by rewrite -{2}(mulVr _ unitr1) mulr1. qed. + +lemma div1r (x : t) : oner / x = invr x. +proof. by rewrite mul1r. qed. + +lemma divr1 (x : t) : x / oner = x. +proof. by rewrite invr1 mulr1. qed. + +lemma unitr0: !unit zero<:t>. +proof. by apply/negP=> /unitrP [y]; rewrite mulr0 eq_sym oner_neq0. qed. + +lemma invr0: invr zero<:t> = zero. +proof. by rewrite invr_out ?unitr0. qed. + +lemma unitrN1: unit<:t> (-oner). +proof. by apply/unitrP; exists (-oner); rewrite mulrNN mulr1. qed. + +lemma invrN1: invr<:t> (-oner) = -oner. +proof. by rewrite -{2}(divrr unitrN1) mulN1r opprK. qed. + +lemma unitrMl (x y : t) : unit y => (unit (x * y) <=> unit x). +proof. +move=> uy; case: (unit x)=> /=; last first. ++ apply/contra=> uxy; apply/unitrP; exists (y * invr (x * y)). + apply/(mulrI (invr y)); first by rewrite unitrV. + rewrite !mulrA mulVr // mul1r; apply/(mulIr y)=> //. + by rewrite -mulrA mulVr // mulr1 mulVr. +move=> ux; apply/unitrP; exists (invr y * invr x). +by rewrite -!mulrA mulKr // mulVr. +qed. + +lemma unitrMr (x y : t) : unit x => (unit (x * y) <=> unit y). +proof. +move=> ux; split=> [uxy|uy]; last by rewrite unitrMl. +by rewrite -(mulKr _ ux y) unitrMl ?unitrV. +qed. + +lemma unitrM (x y : t) : unit (x * y) <=> (unit x /\ unit y). +proof. +case: (unit x) => /=; first by apply: unitrMr. +apply: contra => /unitrP[z] zVE; apply/unitrP. +by exists (y * z); rewrite mulrAC (@mulrC y) (@mulrC _ z). +qed. + +lemma unitrN (x : t) : unit (-x) <=> unit x. +proof. by rewrite -mulN1r unitrMr // unitrN1. qed. + +lemma invrM (x y : t) : unit x => unit y => invr (x * y) = invr y * invr x. +proof. +move=> Ux Uy; have Uxy: unit (x * y) by rewrite unitrMl. +by apply: (mulrI _ Uxy); rewrite mulrV ?mulrA ?mulrK ?mulrV. +qed. + +lemma invrN (x : t) : invr (- x) = - (invr x). +proof. +case: (unit x) => ux; last by rewrite !invr_out ?unitrN. +by rewrite -mulN1r invrM ?unitrN1 // invrN1 mulrN1. +qed. + +lemma invr_neq0 (x : t) : x <> zero => invr x <> zero. +proof. +move=> nx0; case: (unit x)=> Ux; last by rewrite invr_out ?Ux. +by apply/negP=> x'0; move: Ux; rewrite -unitrV x'0 unitr0. +qed. + +lemma invr_eq0 (x : t) : (invr x = zero) <=> (x = zero). +proof. by apply/iff_negb; split=> /invr_neq0; rewrite ?invrK. qed. + +lemma invr_eq1 (x : t) : (invr x = oner) <=> (x = oner). +proof. by rewrite (inv_eq invrK) invr1. qed. + +end section. + +(* -------------------------------------------------------------------- *) +(* Embedding of [int] into a [comring]: [ofint n = intmul oner n]. *) +op ofint ['a <: comring] (n : int) : 'a = intmul oner n. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: comring. + +lemma ofint0 : ofint<:t> 0 = zero. +proof. by apply/mulr0z. qed. + +lemma ofint1 : ofint<:t> 1 = oner. +proof. by apply/mulr1z. qed. + +lemma ofintS (i : int) : 0 <= i => ofint<:t> (i + 1) = oner + ofint i. +proof. by apply/mulrS. qed. + +lemma ofintN (i : int) : ofint<:t> (-i) = - (ofint i). +proof. by apply/mulrNz. qed. + +(* -------------------------------------------------------------------- *) +(* Interaction between additive [intmul] and multiplicative [( * )]. *) +lemma mulrnAl (x y : t) (n : int) : 0 <= n => + (intmul x n) * y = intmul (x * y) n. +proof. +elim: n => [|n ge0n ih]; rewrite !(mulr0z, mulrS) ?mul0r //. +by rewrite mulrDl ih. +qed. + +lemma mulrnAr (x y : t) (n : int) : 0 <= n => + x * (intmul y n) = intmul (x * y) n. +proof. +elim: n => [|n ge0n ih]; rewrite !(mulr0z, mulrS) ?mulr0 //. +by rewrite mulrDr ih. +qed. + +lemma mulrzAl (x y : t) (z : int) : (intmul x z) * y = intmul (x * y) z. +proof. +case: (lezWP 0 z)=> [|_] le; first by rewrite mulrnAl. +by rewrite -oppzK mulrNz mulNr mulrnAl -?mulrNz // oppz_ge0. +qed. + +lemma mulrzAr (x y : t) (z : int) : x * (intmul y z) = intmul (x * y) z. +proof. +case: (lezWP 0 z)=> [|_] le; first by rewrite mulrnAr. +by rewrite -oppzK mulrNz mulrN mulrnAr -?mulrNz // oppz_ge0. +qed. + +lemma mul1r0z (x : t) : x * ofint 0 = zero. +proof. by rewrite ofint0 mulr0. qed. + +lemma mul1r1z (x : t) : x * ofint 1 = x. +proof. by rewrite ofint1 mulr1. qed. + +lemma mul1r2z (x : t) : x * ofint 2 = x + x. +proof. by rewrite /ofint mulr2z mulrDr mulr1. qed. + +lemma mulr_intl (x : t) (z : int) : (ofint z) * x = intmul x z. +proof. by rewrite mulrzAl mul1r. qed. + +lemma mulr_intr (x : t) (z : int) : x * (ofint z) = intmul x z. +proof. by rewrite mulrzAr mulr1. qed. +end section. + +(* -------------------------------------------------------------------- *) +(* Multiplicative exponentiation. Mirrors [intmul] on the additive side + but folds with [( * )] starting at [oner], inverting for negative + exponents. *) +op exp ['a <: comring] (x : 'a) (n : int) = + if n < 0 + then invr (iterop (-n) ( * ) x oner) + else iterop n ( * ) x oner. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: comring. + +lemma expr0 (x : t) : exp x 0 = oner. +proof. by rewrite /exp /= iterop0. qed. + +lemma expr1 (x : t) : exp x 1 = x. +proof. by rewrite /exp /= iterop1. qed. + +(* Multiplicative analogue of [TcMonoid.iteropE], specialised for + [( * )] / [oner] (i.e. [iterop] folded over the mulmonoid view). *) +lemma mul_iteropE (n : int) (x : t) : + iterop n ( * ) x oner = iter n (( * ) x) oner. +proof. +elim/natcase n => [n le0_n|n ge0_n]. ++ by rewrite ?(iter0, iterop0). ++ by rewrite iterSr // mulr1 iteropS. +qed. + +lemma exprS (x : t) (i : int) : 0 <= i => exp x (i+1) = x * (exp x i). +proof. +move=> ge0i; rewrite /exp !ltzNge ge0i addz_ge0 //=. +by rewrite !mul_iteropE iterS. +qed. + +lemma expr_pred (x : t) (i : int) : 0 < i => exp x i = x * (exp x (i - 1)). +proof. smt(exprS). qed. + +lemma exprSr (x : t) (i : int) : 0 <= i => exp x (i+1) = (exp x i) * x. +proof. by move=> ge0_i; rewrite exprS // mulrC. qed. + +lemma expr2 (x : t) : exp x 2 = x * x. +proof. by rewrite (@exprS _ 1) // expr1. qed. + +lemma exprN (x : t) (i : int) : exp x (-i) = invr (exp x i). +proof. +case: (i = 0) => [->|]; first by rewrite oppz0 expr0 invr1. +rewrite /exp oppz_lt0 ltzNge lez_eqVlt oppzK=> -> /=. +by case: (_ < _)%Int => //=; rewrite invrK. +qed. + +lemma exprN1 (x : t) : exp x (-1) = invr x. +proof. by rewrite exprN expr1. qed. + +lemma unitrX (x : t) (m : int) : unit x => unit (exp x m). +proof. +move=> invx; wlog: m / (0 <= m) => [wlog|]. ++ (have [] : (0 <= m \/ 0 <= -m) by move=> /#); first by apply: wlog. + by move=> ?; rewrite -oppzK exprN unitrV &(wlog). +elim: m => [|m ge0_m ih]; first by rewrite expr0 unitr1. +by rewrite exprS // &(unitrMl). +qed. + +lemma unitrX_neq0 (x : t) (m : int) : m <> 0 => unit (exp x m) => unit x. +proof. +wlog: m / (0 < m) => [wlog|]. ++ case: (0 < m); [by apply: wlog | rewrite ltzNge /= => le0_m nz_m]. + by move=> h; (apply: (wlog (-m)); 1,2:smt()); rewrite exprN unitrV. +by move=> gt0_m _; rewrite (_ : m = m - 1 + 1) // exprS 1:/# unitrM. +qed. + +lemma exprV (x : t) (i : int) : exp (invr x) i = exp x (-i). +proof. +wlog: i / (0 <= i) => [wlog|]; first by smt(exprN). +elim: i => /= [|i ge0_i ih]; first by rewrite !expr0. +case: (i = 0) => [->|] /=; first by rewrite exprN1 expr1. +move=> nz_i; rewrite exprS // ih !exprN. +case: (unit x) => [invx|invNx]. ++ by rewrite -invrM ?unitrX // exprS // mulrC. +rewrite !invr_out //; last by rewrite exprS. ++ by apply: contra invNx; apply: unitrX_neq0 => /#. ++ by apply: contra invNx; apply: unitrX_neq0 => /#. +qed. + +lemma exprVn (x : t) (n : int) : 0 <= n => exp (invr x) n = invr (exp x n). +proof. +elim: n => [|n ge0_n ih]; first by rewrite !expr0 invr1. +case: (unit x) => ux. +- by rewrite exprSr -1:exprS // invrM ?unitrX // ih -invrM // unitrX. +- by rewrite !invr_out //; apply: contra ux; apply: unitrX_neq0 => /#. +qed. + +lemma exprMn (x y : t) (n : int) : 0 <= n => exp (x * y) n = exp x n * exp y n. +proof. +elim: n => [|n ge0_n ih]; first by rewrite !expr0 mulr1. +by rewrite !exprS // mulrACA ih. +qed. + +lemma exprD_nneg (x : t) (m n : int) : 0 <= m => 0 <= n => + exp x (m + n) = exp x m * exp x n. +proof. +move=> ge0_m ge0_n; elim: m ge0_m => [|m ge0_m ih]. ++ by rewrite expr0 mul1r. +by rewrite addzAC !exprS ?addz_ge0 // ih mulrA. +qed. + +lemma exprD (x : t) (m n : int) : unit x => exp x (m + n) = exp x m * exp x n. +proof. +wlog: m n x / (0 <= m + n) => [wlog invx|]. ++ case: (0 <= m + n); [by move=> ?; apply: wlog | rewrite lezNgt /=]. + move=> lt0_mDn; rewrite -(@oppzK (m + n)) -exprV. + rewrite -{2}(@oppzK m) -{2}(@oppzK n) -!(@exprV _ (- _)%Int). + by rewrite -wlog 1:/# ?unitrV //#. +move=> ge0_mDn invx; wlog: m n ge0_mDn / (m <= n) => [wlog|le_mn]. ++ by case: (m <= n); [apply: wlog | rewrite mulrC addzC /#]. +(have ge0_n: 0 <= n by move=> /#); elim: n ge0_n m le_mn ge0_mDn. ++ by move=> n _ _ /=; rewrite expr0 mulr1. +move=> n ge0_n ih m le_m_Sn ge0_mDSn; move: ge0_mDSn. +rewrite lez_eqVlt => -[?|]; first have->: n+1 = -m by move=> /#. ++ by rewrite subzz exprN expr0 divrr // unitrX. +move=> gt0_mDSn; move: le_m_Sn; rewrite lez_eqVlt. +case=> [->>|lt_m_Sn]; first by rewrite exprD_nneg //#. +by rewrite addzA exprS 1:/# ih 1,2:/# exprS // mulrCA. +qed. + +lemma exprM (x : t) (m n : int) : + exp x (m * n) = exp (exp x m) n. +proof. +wlog : n / 0 <= n. ++ move=> h; case: (0 <= n) => hn; 1: by apply h. + by rewrite -{1}(@oppzK n) (_: m * - -n = -(m * -n)) 1:/# + exprN h 1:/# exprN invrK. +wlog : m / 0 <= m. ++ move=> h; case: (0 <= m) => hm hn; 1: by apply h. + rewrite -{1}(@oppzK m) (_: (- -m) * n = - (-m) * n) 1:/#. + by rewrite exprN h 1:/# // exprN exprV exprN invrK. +elim/natind: n => [|n hn ih hm _]; 1: smt (expr0). +by rewrite mulzDr exprS //= mulrC exprD_nneg 1:/# 1:// ih. +qed. + +lemma expr0n (n : int) : 0 <= n => exp zero<:t> n = if n = 0 then oner else zero. +proof. +elim: n => [|n ge0_n _]; first by rewrite expr0. +by rewrite exprS // mul0r addz1_neq0. +qed. + +lemma expr0z (z : int) : exp zero<:t> z = if z = 0 then oner else zero. +proof. +case: (0 <= z) => [/expr0n // | /ltzNge lt0_z]. +rewrite -{1}(@oppzK z) exprN; have ->/=: z <> 0 by smt(). +rewrite invr_eq0 expr0n ?oppz_ge0 1:ltzW //. +by have ->/=: -z <> 0 by smt(). +qed. + +lemma expr1z (z : int) : exp oner<:t> z = oner. +proof. +elim/intwlog: z. ++ by move=> n h; rewrite -(@oppzK n) exprN h invr1. ++ by rewrite expr0. ++ by move=> n ge0_n ih; rewrite exprS // mul1r ih. +qed. + +(* -------------------------------------------------------------------- *) +(* Squaring identities. *) +lemma sqrrD (x y : t) : + exp (x + y) 2 = exp x 2 + intmul (x * y) 2 + exp y 2. +proof. +by rewrite !expr2 mulrDl !mulrDr mulr2z !addrA (@mulrC y x). +qed. + +lemma sqrrN (x : t) : exp (-x) 2 = exp x 2. +proof. by rewrite !expr2 mulrNN. qed. + +lemma sqrrB (x y : t) : + exp (x - y) 2 = exp x 2 - intmul (x * y) 2 + exp y 2. +proof. by rewrite sqrrD sqrrN mulrN mulNrz. qed. + +lemma signr_odd (n : int) : 0 <= n => + exp (-oner<:t>) (b2i (odd n)) = exp (-oner) n. +proof. +elim: n => [|n ge0_nih]; first by rewrite odd0 expr0 expr0. +rewrite !(iterS, oddS) // exprS // -/(odd _) => <-. +by case: (odd _); rewrite /b2i /= !(expr0, expr1) mulN1r ?opprK. +qed. + +lemma subr_sqr_1 (x : t) : exp x 2 - oner = (x - oner) * (x + oner). +proof. +rewrite mulrBl mulrDr !(mulr1, mul1r) expr2 -addrA. +by congr; rewrite opprD addrA addrN add0r. +qed. + +(* -------------------------------------------------------------------- *) +(* Left regularity: [lreg x] iff multiplication by [x] on the left is + injective. *) +op lreg ['a <: comring] (x : 'a) = injective (fun y => x * y). + +lemma mulrI_eq0 (x y : t) : lreg x => (x * y = zero) <=> (y = zero). +proof. by move=> reg_x; rewrite -{1}(mulr0 x) (inj_eq reg_x). qed. + +lemma lreg_neq0 (x : t) : lreg x => x <> zero. +proof. +apply/contraL=> ->; apply/negP => /(_ zero oner). +by rewrite (@eq_sym _ oner) oner_neq0 /= !mul0r. +qed. + +lemma mulrI0_lreg (x : t) : + (forall y, x * y = zero => y = zero) => lreg x. +proof. +by move=> reg_x y z eq; rewrite -subr_eq0 &(reg_x) mulrBr eq subrr. +qed. + +lemma lregN (x : t) : lreg x => lreg (-x). +proof. by move=> reg_x y z; rewrite !mulNr => /oppr_inj /reg_x. qed. + +lemma lreg1 : lreg oner<:t>. +proof. by move=> x y; rewrite !mul1r. qed. + +lemma lregM (x y : t) : lreg x => lreg y => lreg (x * y). +proof. by move=> reg_x reg_y z t; rewrite -!mulrA => /reg_x /reg_y. qed. + +lemma lregXn (x : t) (n : int) : 0 <= n => lreg x => lreg (exp x n). +proof. +move=> + reg_x; elim: n => [|n ge0_n ih]. +- by rewrite expr0 &(lreg1). +- by rewrite exprS // &(lregM). +qed. + +(* -------------------------------------------------------------------- *) +lemma fracrDE (n1 n2 d1 d2 : t) : + unit d1 => unit d2 => + n1 / d1 + n2 / d2 = (n1 * d2 + n2 * d1) / (d1 * d2). +proof. +move=> inv_d1 inv_d2; rewrite mulrDl [n1 * d2]mulrC. +by rewrite !invrM //; congr; rewrite mulrACA divrr // ?(mul1r, mulr1). +qed. + +end section. +(* ==================================================================== *) +(* Boolean ring: commutative ring with idempotent multiplication. *) +(* ==================================================================== *) +type class boolring <: comring = { + axiom mulrr : forall (x : boolring), x * x = x +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: boolring. + +lemma addrr (x : t): x + x = zero. +proof. +apply (@addrI (x + x)); rewrite addr0 -{1 2 3 4}[x]mulrr. +by rewrite -mulrDr -mulrDl mulrr. +qed. + +lemma oppr_id (x : t) : -x = x. +proof. by rewrite -[x]opprK -addr_eq0 opprK addrr. qed. + +end section. + +(* ==================================================================== *) +(* Integral domain: commutative ring with no zero divisors. *) +(* ==================================================================== *) +type class idomain <: comring = { + axiom mulf_eq0 : + forall (x y : idomain), x * y = zero<:idomain> <=> x = zero \/ y = zero +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: idomain. + +lemma mulf_neq0 (x y : t) : x <> zero => y <> zero => x * y <> zero. +proof. by move=> nz_x nz_y; apply/negP; rewrite mulf_eq0 /#. qed. + +lemma expf_eq0 (x : t) (n : int) : + (exp x n = zero) <=> (n <> 0 /\ x = zero). +proof. +elim/intwlog: n => [n| |n ge0_n ih]. ++ by rewrite exprN invr_eq0 /#. ++ by rewrite expr0 oner_neq0. +by rewrite exprS // mulf_eq0 ih addz1_neq0 ?andKb. +qed. + +lemma mulfI (x : t) : x <> zero => injective (( * ) x). +proof. +move=> ne0_x y y'; rewrite -(opprK (x * y')) -mulrN -addr_eq0. +by rewrite -mulrDr mulf_eq0 ne0_x /= addr_eq0 opprK. +qed. + +lemma mulIf (x : t) : x <> zero => injective (fun y => y * x). +proof. by move=> nz_x y z; rewrite -!(@mulrC x); exact: mulfI. qed. + +lemma sqrf_eq1 (x : t) : (exp x 2 = oner) <=> (x = oner \/ x = -oner). +proof. by rewrite -subr_eq0 subr_sqr_1 mulf_eq0 subr_eq0 addr_eq0. qed. + +lemma lregP (x : t) : lreg x <=> x <> zero. +proof. by split=> [/lreg_neq0//|/mulfI]. qed. + +lemma eqr_div (x1 y1 x2 y2 : t) : unit y1 => unit y2 => + (x1 / y1 = x2 / y2) <=> (x1 * y2 = x2 * y1). +proof. +move=> Nut1 Nut2; rewrite -{1}(@mulrK y2 _ x1) //. +rewrite -{1}(@mulrK y1 _ x2) // -!mulrA (@mulrC (invr y1)) !mulrA. +split=> [|->] //; + (have nz_Vy1: unit (invr y1) by rewrite unitrV); + (have nz_Vy2: unit (invr y2) by rewrite unitrV). +by move/(mulIr _ nz_Vy1)/(mulIr _ nz_Vy2). +qed. + +end section. + +(* ==================================================================== *) +(* Field: integral domain where every non-zero element is a unit. + The original [Ring.ec] field redefines [unit] via clone-substitution + (`pred unit x <= x <> zeror`); here we keep [unit] as the inherited + predicate and add the equivalence as an axiom of [field]. *) +(* ==================================================================== *) +type class field <: idomain = { + axiom unitfP : forall (x : field), unit x <=> x <> zero<:field> +}. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: field. + +lemma mulfV (x : t) : x <> zero => x * (invr x) = oner. +proof. by move=> nz_x; apply/mulrV/unitfP. qed. + +lemma mulVf (x : t) : x <> zero => (invr x) * x = oner. +proof. by move=> nz_x; apply/mulVr/unitfP. qed. + +lemma divff (x : t) : x <> zero => x / x = oner. +proof. by move=> nz_x; apply/divrr/unitfP. qed. + +lemma invfM (x y : t) : invr (x * y) = invr x * invr y. +proof. +case: (x = zero) => [->|nz_x]; first by rewrite !(mul0r, invr0). +case: (y = zero) => [->|nz_y]; first by rewrite !(mulr0, invr0). +by rewrite invrM ?unitfP // mulrC. +qed. + +lemma invf_div (x y : t) : invr (x / y) = y / x. +proof. by rewrite invfM invrK mulrC. qed. + +lemma eqf_div (x1 y1 x2 y2 : t) : y1 <> zero => y2 <> zero => + (x1 / y1 = x2 / y2) <=> (x1 * y2 = x2 * y1). +proof. by move=> nz_y1 nz_y2; apply: eqr_div; apply/unitfP. qed. + +lemma expfM (x y : t) (n : int) : exp (x * y) n = exp x n * exp y n. +proof. +elim/intwlog: n => [n h | | n ge0_n ih]. ++ by rewrite -(@oppzK n) !(@exprN _ (-n)) h invfM. ++ by rewrite !expr0 mulr1. ++ by rewrite !exprS // mulrCA -!mulrA -ih mulrCA. +qed. + +end section. + +(* ==================================================================== *) +(* Additive morphisms between two [addgroup]s. *) +(* ==================================================================== *) +pred additive ['a <: addgroup, 'b <: addgroup] (f : 'a -> 'b) = + forall (x y : 'a), f (x - y) = f x - f y. + +(* -------------------------------------------------------------------- *) +section. +declare type t1 <: addgroup. +declare type t2 <: addgroup. + +declare op f : t1 -> t2. +declare axiom f_is_additive : additive f. + +lemma raddfB (x y : t1) : f (x - y) = f x - f y. +proof. by apply/f_is_additive. qed. + +lemma raddf0 : f zero<:t1> = zero<:t2>. +proof. by rewrite -(@subr0 zero<:t1>) raddfB subrr. qed. + +lemma raddfN (x : t1) : f (- x) = - (f x). +proof. by rewrite -(@sub0r x) raddfB raddf0 sub0r. qed. + +lemma raddfD (x y : t1) : f (x + y) = f x + f y. +proof. by rewrite -{1}(@opprK y) raddfB raddfN opprK. qed. +end section. + +(* ==================================================================== *) +(* Multiplicative homomorphisms between two [comring]s. *) +(* ==================================================================== *) +pred multiplicative ['a <: comring, 'b <: comring] (f : 'a -> 'b) = + f oner<:'a> = oner<:'b> + /\ forall (x y : 'a), f (x * y) = f x * f y. + +(* ==================================================================== *) +(* Convenience: [(^)] as multiplicative exponentiation on any comring. + Mirrors the [abbrev (^) = exp] declaration in the original + [theories/algebra/Ring.ec:IntID] but is published at top level so + it works for any [comring] carrier (not just [int]). *) +(* ==================================================================== *) +abbrev (^) ['a <: comring] (x : 'a) (n : int) : 'a = exp x n. diff --git a/examples/tcalgebra/sandbox.ec b/examples/tcalgebra/sandbox.ec new file mode 100644 index 0000000000..054768006f --- /dev/null +++ b/examples/tcalgebra/sandbox.ec @@ -0,0 +1,35 @@ +require import AllCore TcMonoid TcRing. + +(* Tvar carrier with multi-parent + factory *) +type class my_comring <: addgroup & (mulmonoid with idm = oner, (+) = mymul) = { + op oner : my_comring + op mymul : my_comring -> my_comring -> my_comring +}. + +section. +declare type t <: my_comring. + +(* Multiplicative side: factory inheritance, abbrev-mediated. *) +lemma test_mulrA : associative ( * )<:t>. +proof. apply addmA. qed. + +lemma test_mulrC : commutative ( * )<:t>. +proof. apply addmC. qed. + +lemma test_mul1r : left_id one<:t> ( * )<:t>. +proof. apply add0m. qed. + +(* Additive side on a multi-parent carrier: [(+)<:t>] is reachable + via two paths to [monoid] (addgroup and mulmonoid-with-renaming), + but only the addgroup path leaves [(+)] unrenamed. Op-name-aware + path resolution should pick that path uniquely. *) +lemma test_addrA : associative (+)<:t>. +proof. apply addmA. qed. + +lemma test_addrC : commutative (+)<:t>. +proof. apply addmC. qed. + +lemma test_add0r : left_id zero<:t> (+)<:t>. +proof. apply add0m. qed. + +end section. diff --git a/examples/tcstdlib/TcBigop.ec b/examples/tcstdlib/TcBigop.ec new file mode 100644 index 0000000000..61c157b49c --- /dev/null +++ b/examples/tcstdlib/TcBigop.ec @@ -0,0 +1,590 @@ +(* This API has been mostly inspired from the [bigop] library of the + * ssreflect Coq extension. *) + +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import AllCore List Ring TcMonoid. + +import Ring.IntID. + +(* -------------------------------------------------------------------- *) +section. +declare type t <: monoid. + +(* -------------------------------------------------------------------- *) +op big (P : 'a -> bool) (F : 'a -> t) (r : 'a list) = + foldr (+) idm (map F (filter P r)). + +(* -------------------------------------------------------------------- *) +abbrev bigi (P : int -> bool) (F : int -> t) i j = + big P F (range i j). + +(* -------------------------------------------------------------------- *) +lemma big_nil (P : 'a -> bool) (F : 'a -> t): big P F [] = idm. +proof. by []. qed. + +(* -------------------------------------------------------------------- *) +lemma big_cons (P : 'a -> bool) (F : 'a -> t) x s: + big P F (x :: s) = if P x then F x + big P F s else big P F s. +proof. by rewrite {1}/big /= (@fun_if (map F)); case (P x). qed. + +lemma big_consT (F : 'a -> t) x s: + big predT F (x :: s) = F x + big predT F s. +proof. by apply/big_cons. qed. + +(* -------------------------------------------------------------------- *) +lemma big_rec (K : t -> bool) r P (F : 'a -> t): + K idm => (forall i x, P i => K x => K (F i + x)) => K (big P F r). +proof. + move=> K0 Kop; elim: r => //= i r; rewrite big_cons. + by case (P i) => //=; apply/Kop. +qed. + +lemma big_ind (K : t -> bool) r P (F : 'a -> t): + (forall x y, K x => K y => K (x + y)) + => K idm => (forall i, P i => K (F i)) + => K (big P F r). +proof. + move=> Kop Kidx K_F; apply/big_rec => //. + by move=> i x Pi Kx; apply/Kop => //; apply/K_F. +qed. + +lemma big_rec2: + forall (K : t -> t -> bool) r P (F1 F2 : 'a -> t), + K idm idm + => (forall i y1 y2, P i => K y1 y2 => K (F1 i + y1) (F2 i + y2)) + => K (big P F1 r) (big P F2 r). +proof. + move=> K r P F1 F2 KI KF; elim: r => //= i r IHr. + by rewrite !big_cons; case (P i) => ? //=; apply/KF. +qed. + +lemma big_ind2: + forall (K : t -> t -> bool) r P (F1 F2 : 'a -> t), + (forall x1 x2 y1 y2, K x1 x2 => K y1 y2 => K (x1 + y1) (x2 + y2)) + => K idm idm + => (forall i, P i => K (F1 i) (F2 i)) + => K (big P F1 r) (big P F2 r). +proof. + move=> K r P F1 F2 Kop KI KF; apply/big_rec2 => //. + by move=> i x1 x2 Pi Kx1x2; apply/Kop => //; apply/KF. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_endo (f : t -> t): + f idm = idm + => (forall (x y : t), f (x + y) = f x + f y) + => forall r P (F : 'a -> t), + f (big P F r) = big P (f \o F) r. +proof. + (* FIXME: should be a consequence of big_morph *) + move=> fI fM; elim=> //= i r IHr P F; rewrite !big_cons. + by case (P i) => //=; rewrite 1?fM IHr. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_map ['a 'b] (h : 'b -> 'a) (P : 'a -> bool) F s: + big P F (map h s) = big (P \o h) (F \o h) s. +proof. by elim: s => // x s; rewrite map_cons !big_cons=> ->. qed. + +lemma big_mapT ['a 'b] (h : 'b -> 'a) F s: (* -> big_map_predT *) + big predT F (map h s) = big predT (F \o h) s. +proof. by rewrite big_map. qed. + +(* -------------------------------------------------------------------- *) +lemma big_comp ['a] (h : t -> t) (P : 'a -> bool) F s: + h idm = idm => morphism_2 h (+) (+) => + h (big P F s) = big P (h \o F) s. +proof. + move=> Hidm Hh;elim: s => // x s; rewrite !big_cons => <-. + by rewrite /(\o) -Hh;case (P x) => //. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_nth x0 (P : 'a -> bool) (F : 'a -> t) s: + big P F s = bigi (P \o (nth x0 s)) (F \o (nth x0 s)) 0 (size s). +proof. by rewrite -{1}(@mkseq_nth x0 s) /mkseq big_map. qed. + +(* -------------------------------------------------------------------- *) +lemma big_const (P : 'a -> bool) x s: + big P (fun i => x) s = iter (count P s) ((+) x) idm. +proof. + elim: s=> [|y s ih]; [by rewrite iter0 | rewrite big_cons /=]. + by rewrite ih; case (P y) => //; rewrite addzC iterS // count_ge0. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_seq1 (F : 'a -> t) x: big predT F [x] = F x. +proof. by rewrite big_cons big_nil addm0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_mkcond (P : 'a -> bool) (F : 'a -> t) s: + big P F s = big predT (fun i => if P i then F i else idm) s. +proof. + elim: s=> // x s ih; rewrite !big_cons -ih /predT /=. + by case (P x)=> //; rewrite add0m. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_filter (P : 'a -> bool) F s: + big predT F (filter P s) = big P F s. +proof. by elim: s => //= x s; case (P x)=> //; rewrite !big_cons=> -> ->. qed. + +(* -------------------------------------------------------------------- *) +lemma big_filter_cond (P1 P2 : 'a -> bool) F s: + big P2 F (filter P1 s) = big (predI P1 P2) F s. +proof. by rewrite -big_filter -(@big_filter _ _ s) predIC filter_predI. qed. + +(* -------------------------------------------------------------------- *) +lemma eq_bigl (P1 P2 : 'a -> bool) (F : 'a -> t) s: + (forall i, P1 i <=> P2 i) + => big P1 F s = big P2 F s. +proof. by move=> h; rewrite /big (eq_filter h). qed. + +(* -------------------------------------------------------------------- *) +lemma eq_bigr (P : 'a -> bool) (F1 F2 : 'a -> t) s: + (forall i, P i => F1 i = F2 i) + => big P F1 s = big P F2 s. +proof. (* FIXME: big_rec2 *) + move=> eqF; elim: s=> // x s; rewrite !big_cons=> <-. + by case (P x)=> // /eqF <-. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_distrl ['a] (op_ : t -> t -> t) (P : 'a -> bool) F s u: + left_zero idm op_ + => left_distributive op_ (+) + => op_ (big P F s) u = big P (fun a => op_ (F a) u) s. +proof. + move=> mulm1 mulmDl; pose G := fun x => op_ x u. + move: (big_comp G P) => @/G /= -> //. + by rewrite mulm1. by move=> t1 t2; rewrite mulmDl. +qed. + +lemma big_distrr ['a] (op_ : t -> t -> t) (P : 'a -> bool) F s u: + right_zero idm op_ + => right_distributive op_ (+) + => op_ u (big P F s) = big P (fun a => op_ u (F a)) s. +proof. + move=> mul1m mulmDr; pose G := fun x => op_ u x. + move: (big_comp G P) => @/G /= -> //. + by rewrite mul1m. by move=> t1 t2; rewrite mulmDr. +qed. + +lemma big_distr ['a 'b] (op_ : t -> t -> t) + (P1 : 'a -> bool) (P2 : 'b -> bool) F1 s1 F2 s2 : + commutative op_ + => left_zero idm op_ + => left_distributive op_ (+) + => op_ (big P1 F1 s1) (big P2 F2 s2) = + big P1 (fun a1 => big P2 (fun a2 => op_ (F1 a1) (F2 a2)) s2) s1. +proof. + move=> mulmC mulm1 mulmDl; rewrite big_distrl //. + apply/eq_bigr=> i _ /=; rewrite big_distrr //. + by move=> x; rewrite mulmC mulm1. + by move=> x y z; rewrite !(mulmC x) mulmDl. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_andbC (P Q : 'a -> bool) (F : 'a -> t) s: + big (fun x => P x /\ Q x) F s = big (fun x => Q x /\ P x) F s. +proof. by apply/eq_bigl=> i. qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big (P1 P2 : 'a -> bool) (F1 F2 : 'a -> t) s: + (forall i, P1 i <=> P2 i) + => (forall i, P1 i => F1 i = F2 i) + => big P1 F1 s = big P2 F2 s. +proof. by move=> /eq_bigl <- /eq_bigr <-. qed. + +(* -------------------------------------------------------------------- *) +lemma congr_big r1 r2 P1 P2 (F1 F2 : 'a -> t): + r1 = r2 + => (forall x, P1 x <=> P2 x) + => (forall i, P1 i => F1 i = F2 i) + => big P1 F1 r1 = big P2 F2 r2. +proof. by move=> <-; apply/eq_big. qed. + +(* -------------------------------------------------------------------- *) +lemma big_hasC (P : 'a -> bool) (F : 'a -> t) s: !has P s => + big P F s = idm. +proof. + rewrite -big_filter has_count -size_filter. + by rewrite ltz_def size_ge0 /= => /size_eq0 ->. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_pred0_eq (F : 'a -> t) s: big pred0 F s = idm. +proof. by rewrite big_hasC // has_pred0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_pred0 (P : 'a -> bool) (F : 'a -> t) s: + (forall i, P i <=> false) => big P F s = idm. +proof. by move=> h; rewrite -(@big_pred0_eq F s); apply/eq_bigl. qed. + +(* -------------------------------------------------------------------- *) +lemma big_cat (P : 'a -> bool) (F : 'a -> t) s1 s2: + big P F (s1 ++ s2) = big P F s1 + big P F s2. +proof. + rewrite !(@big_mkcond P); elim: s1 => /= [|i s1 ih]. + by rewrite (@big_nil P F) add0m. + by rewrite !big_cons /(predT i) /= ih addmA. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_catl (P : 'a -> bool) (F : 'a -> t) s1 s2: !has P s2 => + big P F (s1 ++ s2) = big P F s1. +proof. by rewrite big_cat => /big_hasC ->; rewrite addm0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_catr (P : 'a -> bool) (F : 'a -> t) s1 s2: !has P s1 => + big P F (s1 ++ s2) = big P F s2. +proof. by rewrite big_cat => /big_hasC ->; rewrite add0m. qed. + +(* -------------------------------------------------------------------- *) +lemma big_rcons (P : 'a -> bool) (F : 'a -> t) s x: + big P F (rcons s x) = if P x then big P F s + F x else big P F s. +proof. + by rewrite -cats1 big_cat big_cons big_nil; case: (P x); rewrite !addm0. +qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_perm (P : 'a -> bool) (F : 'a -> t) s1 s2: + perm_eq s1 s2 => big P F s1 = big P F s2. +proof. + move=> /perm_eqP; rewrite !(@big_mkcond P). + elim s1 s2 => [|i s1 ih1] s2 eq_s12. + + case: s2 eq_s12=> // i s2 h. + by have := h (pred1 i)=> //=; smt(count_ge0). + have r2i: mem s2 i by rewrite -has_pred1 has_count -eq_s12 #smt:(count_ge0). + have/splitPr [s3 s4] ->> := r2i. + rewrite big_cat !big_cons /(predT i) /=. + rewrite addmCA; congr; rewrite -big_cat; apply/ih1=> a. + by have := eq_s12 a; rewrite !count_cat /= addzCA => /addzI. +qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_perm_map (F : 'a -> t) s1 s2: + perm_eq (map F s1) (map F s2) => big predT F s1 = big predT F s2. +proof. +by move=> peq; rewrite -!(@big_map F predT idfun) &(eq_big_perm). +qed. + +(* -------------------------------------------------------------------- *) +lemma big_seq_cond (P : 'a -> bool) (F : 'a -> t) s: + big P F s = big (fun i => mem s i /\ P i) F s. +proof. by rewrite -!(@big_filter _ _ s); congr; apply/eq_in_filter. qed. + +(* -------------------------------------------------------------------- *) +lemma big_seq (F : 'a -> t) s: + big predT F s = big (fun i => mem s i) F s. +proof. by rewrite big_seq_cond; apply/eq_bigl. qed. + +(* -------------------------------------------------------------------- *) +lemma big_rem (P : 'a -> bool) (F : 'a -> t) s x: mem s x => + big P F s = (if P x then F x else idm) + big P F (rem x s). +proof. + by move/perm_to_rem/eq_big_perm=> ->; rewrite !(@big_mkcond P) big_cons. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigD1 (F : 'a -> t) s x: mem s x => uniq s => + big predT F s = F x + big (predC1 x) F s. +proof. by move=> /big_rem-> /rem_filter->; rewrite big_filter. qed. + +(* -------------------------------------------------------------------- *) +lemma bigD1_cond P (F : 'a -> t) s x: P x => mem s x => uniq s => + big P F s = F x + big (predI P (predC1 x)) F s. +proof. +move=> Px sx uqs; rewrite -big_filter (@bigD1 _ _ x) ?big_filter_cond //. + by rewrite mem_filter Px. by rewrite filter_uniq. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigD1_cond_if P (F : 'a -> t) s x: uniq s => big P F s = + (if mem s x /\ P x then F x else idm) + big (predI P (predC1 x)) F s. +proof. +case: (mem s x /\ P x) => [[Px sx]|Nsx]; rewrite ?add0m /=. + by apply/bigD1_cond. +move=> uqs; rewrite big_seq_cond eq_sym big_seq_cond; apply/eq_bigl=> i /=. +by case: (i = x) => @/predC1 @/predI [->>|]. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_split (P : 'a -> bool) (F1 F2 : 'a -> t) s: + big P (fun i => F1 i + F2 i) s = big P F1 s + big P F2 s. +proof. + elim: s=> /= [|x s ih]; 1: by rewrite !big_nil addm0. + rewrite !big_cons ih; case: (P x) => // _. + by rewrite addmCA -!addmA addmCA. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigID (P : 'a -> bool) (F : 'a -> t) (a : 'a -> bool) s: + big P F s = big (predI P a) F s + big (predI P (predC a)) F s. +proof. +rewrite !(@big_mkcond _ F) -big_split; apply/eq_bigr => i _ /=. +by rewrite /predI /predC; case: (a i); rewrite ?addm0 ?add0m. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigU ['a] (P Q : 'a -> bool) (F : 'a -> t) s : (forall x, !(P x /\ Q x)) => + big (predU P Q) F s = big P F s + big Q F s. +proof. +move=> dj_PQ; rewrite (@bigID (predU _ _) _ P). +by congr; apply: eq_bigl => /#. +qed. + +(* -------------------------------------------------------------------- *) +lemma bigEM ['a] (P : 'a -> bool) (F : 'a -> t) s : + big predT F s = big P F s + big (predC P) F s. +proof. by rewrite -bigU 1:/#; apply: eq_bigl => /#. qed. + +(* -------------------------------------------------------------------- *) +lemma big_reindex ['a 'b] + (P : 'a -> bool) (F : 'a -> t) (f : 'b -> 'a) (f' : 'a -> 'b) (s : 'a list) : + (forall x, x \in s => f (f' x) = x) + => big P F s = big (P \o f) (F \o f) (map f' s). +proof. +by move => /eq_in_map id_ff'; rewrite -big_map -map_comp id_ff' id_map. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_pair_pswap ['a 'b] (p : 'a * 'b -> bool) (f : 'a * 'b -> t) s : + big<:'a * 'b> p f s + = big<:'b * 'a> (p \o pswap) (f \o pswap) (map pswap s). +proof. by apply/big_reindex; case. qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_seq (F1 F2 : 'a -> t) s: + (forall x, mem s x => F1 x = F2 x) + => big predT F1 s = big predT F2 s. +proof. by move=> eqF; rewrite !big_seq; apply/eq_bigr. qed. + +(* -------------------------------------------------------------------- *) +lemma congr_big_seq (P1 P2: 'a -> bool) (F1 F2 : 'a -> t) s: + (forall x, mem s x => P1 x = P2 x) => + (forall x, mem s x => P1 x => P2 x => F1 x = F2 x) + => big P1 F1 s = big P2 F2 s. +proof. + move=> eqP eqH; rewrite big_mkcond eq_sym big_mkcond eq_sym. + apply/eq_big_seq=> x x_in_s /=; rewrite eqP //. + by case (P2 x)=> // P2x; rewrite eqH // eqP. +qed. + +(* -------------------------------------------------------------------- *) +lemma big1_eq (P : 'a -> bool) s: big P (fun (x : 'a) => idm) s = idm. +proof. + rewrite big_const; elim/natind: (count _ _)=> n. + by move/iter0<:t> => ->. + by move/iterS<:t> => -> ->; rewrite addm0. +qed. + +(* -------------------------------------------------------------------- *) +lemma big1 (P : 'a -> bool) (F : 'a -> t) s: + (forall i, P i => F i = idm) => big P F s = idm. +proof. by move/eq_bigr=> ->; apply/big1_eq. qed. + +(* -------------------------------------------------------------------- *) +lemma big1_seq (P : 'a -> bool) (F : 'a -> t) s: + (forall i, P i /\ (mem s i) => F i = idm) => big P F s = idm. +proof. by move=> eqF1; rewrite big_seq_cond big_andbC big1. qed. + +(* -------------------------------------------------------------------- *) +lemma big_eq_idm_filter ['a] (P : 'a -> bool) (F : 'a -> t) s : + (forall (x : 'a), !P x => F x = idm) => big predT F s = big P F s. +proof. +by move=> eq1; rewrite (@bigEM P) (@big1 (predC _)) // addm0. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_flatten (P : 'a -> bool) (F : 'a -> t) rr : + big P F (flatten rr) = big predT (fun s => big P F s) rr. +proof. +elim: rr => /= [|r rr ih]; first by rewrite !big_nil. +by rewrite flatten_cons big_cat big_cons -ih. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_pair ['a 'b] (F : 'a * 'b -> t) (s : ('a * 'b) list) : uniq s => + big predT F s = + big predT (fun a => + big predT F (filter (fun xy : _ * _ => xy.`1 = a) s)) + (undup (map fst s)). +proof. +move=> /perm_eq_pair /eq_big_perm /(_ predT F) ->. +by rewrite big_flatten big_map. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_nseq_cond (P : 'a -> bool) (F : 'a -> t) n x : + big P F (nseq n x) = if P x then iter n ((+) (F x)) idm else idm. +proof. +elim/natind: n => [n le0_n|n ge0_n ih]; first by rewrite ?(nseq0_le, iter0). +by rewrite nseqS // big_cons ih; case: (P x) => //; rewrite iterS. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_nseq (F : 'a -> t) n x : + big predT F (nseq n x) = iter n ((+) (F x)) idm. +proof. by apply/big_nseq_cond. qed. + +(* -------------------------------------------------------------------- *) +lemma big_undup ['a] (P : 'a -> bool) (F : 'a -> t) s : + big P F s = big P (fun a => iter (count (pred1 a) s) ((+) (F a)) idm) (undup s). +proof. +have <- := eq_big_perm P F _ _ (perm_undup_count s). +rewrite big_flatten big_map (@big_mkcond P); apply/eq_big => //=. +by move=> @/(\o) /= x _; apply/big_nseq_cond. +qed. + +(* -------------------------------------------------------------------- *) +lemma exchange_big (P1 : 'a -> bool) (P2 : 'b -> bool) (F : 'a -> 'b -> t) s1 s2: + big P1 (fun a => big P2 (F a) s2) s1 = + big P2 (fun b => big P1 (fun a => F a b) s1) s2. +proof. + elim: s1 s2 => [|a s1 ih] s2; first by rewrite big_nil big1_eq. + rewrite big_cons ih; case: (P1 a)=> h; rewrite -?big_split; + by apply/eq_bigr=> x _ /=; rewrite big_cons h. +qed. + +(* -------------------------------------------------------------------- *) +lemma partition_big ['a 'b] (px : 'a -> 'b) P Q (F : 'a -> t) s s' : + uniq s' + => (forall x, mem s x => P x => mem s' (px x) /\ Q (px x)) + => big P F s = big Q (fun x => big (fun y => P y /\ px y = x) F s) s'. +proof. +move=> uq_s'; elim: s => /~= [|x xs ih] hm. + by rewrite big_nil big1_eq. +rewrite big_cons; case: (P x) => /= [Px|PxN]; last first. + rewrite ih //; 1: by move=> y y_xs; apply/hm; rewrite y_xs. + by apply/eq_bigr=> i _ /=; rewrite big_cons /= PxN. +have := hm x; rewrite Px /= => -[s'_px Qpx]; apply/eq_sym. +rewrite (@bigD1_cond _ _ _ (px x)) //= big_cons /= Px /=. +rewrite -addmA; congr; apply/eq_sym; rewrite ih. + by move=> y y_xs; apply/hm; rewrite y_xs. +rewrite (@bigD1_cond _ _ _ (px x)) //=; congr. +apply/eq_bigr=> /= i [Qi @/predC1]; rewrite eq_sym => ne_pxi. +by rewrite big_cons /= ne_pxi. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_allpairs (f : 'a -> 'b -> 'c) (F : 'c -> t) s u: + big predT F (allpairs<:'a, 'b, 'c> f s u) + = big predT (fun x => big predT (fun y => F (f x y)) u) s. +proof. +elim: s u => [|x s ih] u //=. +by rewrite allpairs_consl big_cat ih big_consT big_map. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_cond m n P (F : int -> t): + bigi P F m n = bigi (fun i => m <= i < n /\ P i) F m n. +proof. by rewrite big_seq_cond; apply/eq_bigl=> i /=; rewrite mem_range. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int m n (F : int -> t): + bigi predT F m n = bigi (fun i => m <= i < n) F m n. +proof. by rewrite big_int_cond. qed. + +(* -------------------------------------------------------------------- *) +lemma congr_big_int (m1 n1 m2 n2 : int) P1 P2 (F1 F2 : int -> t): + m1 = m2 => n1 = n2 + => (forall i, m1 <= i < n2 => P1 i = P2 i) + => (forall i, P1 i /\ (m1 <= i < n2) => F1 i = F2 i) + => bigi P1 F1 m1 n1 = bigi P2 F2 m2 n2. +proof. + move=> <- <- eqP12 eqF12; rewrite big_seq_cond (@big_seq_cond P2). + by apply/eq_big=> i /=; rewrite mem_range #smt:(). +qed. + +(* -------------------------------------------------------------------- *) +lemma eq_big_int (m n : int) (F1 F2 : int -> t): + (forall i, m <= i < n => F1 i = F2 i) + => bigi predT F1 m n = bigi predT F2 m n. +proof. by move=> eqF; apply/congr_big_int. qed. + +(* -------------------------------------------------------------------- *) +lemma big_ltn_cond (m n : int) P (F : int -> t): m < n => + let x = bigi P F (m+1) n in + bigi P F m n = if P m then F m + x else x. +proof. by move/range_ltn=> ->; rewrite big_cons. qed. + +(* -------------------------------------------------------------------- *) +lemma big_ltn (m n : int) (F : int -> t): m < n => + bigi predT F m n = F m + bigi predT F (m+1) n. +proof. by move/big_ltn_cond=> /= ->. qed. + +(* -------------------------------------------------------------------- *) +lemma big_geq (m n : int) P (F : int -> t): n <= m => + bigi P F m n = idm. +proof. by move/range_geq=> ->; rewrite big_nil. qed. + +(* -------------------------------------------------------------------- *) +lemma big_addn (m n a : int) P (F : int -> t): + bigi P F (m+a) n + = bigi (fun i => P (i+a)) (fun i => F (i+a)) m (n-a). +proof. +rewrite range_addl big_map; apply/eq_big. + by move=> i /=; rewrite /(\o) addzC. +by move=> i /= _; rewrite /(\o) addzC. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_int1 n (F : int -> t): bigi predT F n (n+1) = F n. +proof. by rewrite big_ltn 1:/# big_geq // addm0. qed. + +(* -------------------------------------------------------------------- *) +lemma big_cat_int (n m p : int) P (F : int -> t): m <= n => n <= p => + bigi P F m p = (bigi P F m n) + (bigi P F n p). +proof. by move=> lemn lenp; rewrite -big_cat -range_cat. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recl (n m : int) (F : int -> t): m <= n => + bigi predT F m (n+1) = F m + bigi predT (fun i => F (i+1)) m n. +proof. by move=> lemn; rewrite big_ltn 1?big_addn /= 1:/#. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recr (n m : int) (F : int -> t): m <= n => + bigi predT F m (n+1) = bigi predT F m n + F n. +proof. by move=> lemn; rewrite (@big_cat_int n) ?big_int1 //#. qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recl_cond (n m : int) P (F : int -> t): m <= n => + bigi P F m (n+1) = + (if P m then F m else idm) + + bigi (fun i => P (i+1)) (fun i => F (i+1)) m n. +proof. +by move=> lemn; rewrite big_mkcond big_int_recl //= -big_mkcond. +qed. + +(* -------------------------------------------------------------------- *) +lemma big_int_recr_cond (n m : int) P (F : int -> t): m <= n => + bigi P F m (n+1) = + bigi P F m n + (if P n then F n else idm). +proof. by move=> lemn; rewrite !(@big_mkcond P) big_int_recr. qed. + +(* -------------------------------------------------------------------- *) +lemma bigi_split_odd_even (n : int) (F : int -> t) : 0 <= n => + bigi predT (fun i => F (2 * i) + F (2 * i + 1)) 0 n + = bigi predT F 0 (2 * n). +proof. +move=> ge0_n; rewrite big_split; pose rg := range 0 n. +rewrite -(@big_mapT (fun i => 2 * i)). +rewrite -(@big_mapT (fun i => 2 * i + 1)). +rewrite -big_cat &(eq_big_perm) &(uniq_perm_eq) 2:&(range_uniq). +- rewrite cat_uniq !map_inj_in_uniq /= ~-1:/# range_uniq /=. + apply/hasPn => _ /mapP[y] /= [_ ->]. + by apply/negP; case/mapP=> ? [_] /#. +move=> x; split. +- rewrite mem_cat; case=> /mapP[y] /=; + case=> /mem_range y_rg -> {x}; apply/mem_range; + by smt(). +move/mem_range => x_rg; rewrite mem_cat. +have: forall (i : int), exists j, i = 2 * j \/ i = 2 * j + 1 by smt(). +- case/(_ x) => y [] ->>; [left | right]; apply/mapP=> /=; + by exists y; (split; first apply/mem_range); smt(). +qed. + +end section. diff --git a/examples/tcstdlib/TcMonoid.ec b/examples/tcstdlib/TcMonoid.ec new file mode 100644 index 0000000000..f33a9da550 --- /dev/null +++ b/examples/tcstdlib/TcMonoid.ec @@ -0,0 +1,35 @@ +require import Int. + +(* -------------------------------------------------------------------- *) +type class monoid = { + op idm : monoid + op (+) : monoid -> monoid -> monoid + + axiom addmA: associative (+) + axiom addmC: commutative (+) + axiom add0m: left_id idm (+) +}. + +(* -------------------------------------------------------------------- *) +section. +declare type m <: monoid. + +lemma addm0: right_id idm (+)<:m>. +proof. by move=> x; rewrite addmC add0m. qed. + +lemma addmCA: left_commutative (+)<:m>. +proof. by move=> x y z; rewrite !addmA (addmC x). qed. + +lemma addmAC: right_commutative (+)<:m>. +proof. by move=> x y z; rewrite -!addmA (addmC y). qed. + +lemma addmACA: interchange (+)<:m> (+). +proof. by move=> x y z t; rewrite -!addmA (addmCA y). qed. + +lemma iteropE n (x : m): iterop n (+) x idm = iter n ((+) x) idm. +proof. +elim/natcase n => [n le0_n|n ge0_n]. ++ by rewrite ?(iter0, iterop0). ++ by rewrite iterSr // addm0 iteropS. +qed. +end section. diff --git a/examples/tcstdlib/TcRing.ec b/examples/tcstdlib/TcRing.ec new file mode 100644 index 0000000000..6f7a589834 --- /dev/null +++ b/examples/tcstdlib/TcRing.ec @@ -0,0 +1,858 @@ +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import Core Int TcMonoid. + +(* -------------------------------------------------------------------- *) +type class group <: monoid = { + op [ - ] : group -> group + + axiom addNr: left_inverse idm [-] (+)<:group> +}. + +section. +declare type g <: group. + +abbrev zeror = idm<:g>. +abbrev ( - ) (x y : g) = x + -y. + +(* -------------------------------------------------------------------- *) +lemma addrA: associative (+)<:g>. +proof. by exact: addmA. qed. + +lemma addrC: commutative (+)<:g>. +proof. by exact: addmC. qed. + +lemma add0r: left_id zeror (+)<:g>. +proof. by exact: add0m. qed. + +(* -------------------------------------------------------------------- *) +lemma addr0: right_id zeror (+)<:g>. +proof. by move=> x; rewrite addrC add0r. qed. + +lemma addrN: right_inverse zeror [-] (+)<:g>. +proof. by move=> x; rewrite addrC addNr. qed. + +lemma addrCA: left_commutative (+)<:g>. +proof. by move=> x y z; rewrite !addrA (@addrC x y). qed. + +lemma addrAC: right_commutative (+)<:g>. +proof. by move=> x y z; rewrite -!addrA (@addrC y z). qed. + +lemma addrACA: interchange (+)<:g> (+)<:g>. +proof. by move=> x y z t; rewrite -!addrA (addrCA y). qed. + +lemma subrr (x : g): x - x = zeror. +proof. by rewrite addrN. qed. + +lemma addKr: left_loop [-] (+)<:g>. +proof. by move=> x y; rewrite addrA addNr add0r. qed. + +lemma addNKr: rev_left_loop [-] (+)<:g>. +proof. by move=> x y; rewrite addrA addrN add0r. qed. + +lemma addrK: right_loop [-] (+)<:g>. +proof. by move=> x y; rewrite -addrA addrN addr0. qed. + +lemma addrNK: rev_right_loop [-] (+)<:g>. +proof. by move=> x y; rewrite -addrA addNr addr0. qed. + +lemma subrK (x y : g): (x - y) + y = x. +proof. by rewrite addrNK. qed. + +lemma addrI: right_injective (+)<:g>. +proof. by move=> x y z h; rewrite -(@addKr x z) -h addKr. qed. + +lemma addIr: left_injective (+)<:g>. +proof. by move=> x y z h; rewrite -(@addrK x z) -h addrK. qed. + +lemma opprK: involutive [-]<:g>. +proof. by move=> x; apply (@addIr (-x)); rewrite addNr addrN. qed. + +lemma oppr_inj : injective [-]<:g>. +proof. by move=> x y eq; apply/(addIr (-x)); rewrite subrr eq subrr. qed. + +lemma oppr0 : -zeror = zeror. +proof. by rewrite -(@addr0 (-zeror)) addNr. qed. + +lemma oppr_eq0 (x : g) : (- x = zeror) <=> (x = zeror). +proof. by rewrite (inv_eq opprK) oppr0. qed. + +lemma subr0 (x : g): x - zeror = x. +proof. by rewrite oppr0 addr0. qed. + +lemma sub0r (x : g): zeror - x = - x. +proof. by rewrite add0r. qed. + +lemma opprD (x y : g): -(x + y) = -x + -y. +proof. by apply (@addrI (x + y)); rewrite addrA addrN addrAC addrK addrN. qed. + +lemma opprB (x y : g): -(x - y) = y - x. +proof. by rewrite opprD opprK addrC. qed. + +lemma subrACA: interchange (-) (+)<:g>. +proof. by move=> x y z t; rewrite addrACA opprD. qed. + +lemma subr_eq (x y z : g): + (x - z = y) <=> (x = y + z). +proof. +move: (can2_eq (fun x, x - z) (fun x, x + z) _ _ x y) => //=. ++ by move=> {x} x /=; rewrite addrNK. ++ by move=> {x} x /=; rewrite addrK. +qed. + +lemma subr_eq0 (x y : g): (x - y = zeror) <=> (x = y). +proof. by rewrite subr_eq add0r. qed. + +lemma addr_eq0 (x y : g): (x + y = zeror) <=> (x = -y). +proof. by rewrite -(@subr_eq0 x) opprK. qed. + +lemma eqr_opp (x y : g): (- x = - y) <=> (x = y). +proof. by apply/(@can_eq _ _ opprK x y). qed. + +lemma eqr_oppLR (x y : g) : (- x = y) <=> (x = - y). +proof. by apply/(@inv_eq _ opprK x y). qed. + +lemma eqr_sub (x y z t : g) : (x - y = z - t) <=> (x + t = z + y). +proof. +rewrite -{1}(addrK t x) -{1}(addrK y z) -!addrA. +by rewrite (addrC (-t)) !addrA; split=> [/addIr /addIr|->//]. +qed. + +lemma subr_add2r (z x y : g): (x + z) - (y + z) = x - y. +proof. by rewrite opprD addrACA addrN addr0. qed. + +op intmul (x : g) (n : int) = + (* (signz n) * (iterop `|n| ZModule.(+) x zeror) *) + if n < 0 + then -(iterop (-n) (+)<:g> x zeror) + else (iterop n (+)<:g> x zeror). + +lemma intmulpE (z : g) c : 0 <= c => + intmul z c = iterop c (+)<:g> z zeror. +proof. by rewrite /intmul lezNgt => ->. qed. + +lemma mulr0z (x : g): intmul x 0 = zeror. +proof. by rewrite /intmul /= iterop0. qed. + +lemma mulr1z (x : g): intmul x 1 = x. +proof. by rewrite /intmul /= iterop1. qed. + +lemma mulr2z (x : g): intmul x 2 = x + x. +proof. by rewrite /intmul /= (@iteropS 1) // (@iterS 0) // iter0. qed. + +lemma mulrNz (x : g) (n : int): intmul x (-n) = -(intmul x n). +proof. +case: (n = 0)=> [->|nz_c]; first by rewrite oppz0 mulr0z oppr0. +rewrite /intmul oppz_lt0 oppzK ltz_def nz_c lezNgt /=. +by case: (n < 0); rewrite ?opprK. +qed. + +lemma mulrS (x : g) (n : int): 0 <= n => + intmul x (n+1) = x + intmul x n. +proof. +move=> ge0n; rewrite !intmulpE 1:addz_ge0 //. +by rewrite !iteropE iterS. +qed. + +lemma mulNrz (x : g) n : intmul (- x) n = - (intmul x n). +proof. +elim/intwlog: n => [n h| | n ge0_n ih]. ++ by rewrite -(@oppzK n) !(@mulrNz _ (- n)) h. ++ by rewrite !mulr0z oppr0. ++ by rewrite !mulrS // ih opprD. +qed. + +lemma mulNrNz (x : g) (n : int) : intmul (-x) (-n) = intmul x n. +proof. by rewrite mulNrz mulrNz opprK. qed. + +lemma mulrSz (x : g) n : intmul x (n + 1) = x + intmul x n. +proof. +case: (0 <= n) => [/mulrS ->//|]; rewrite -ltzNge => gt0_n. +case: (n = -1) => [->/=|]; 1: by rewrite mulrNz mulr1z mulr0z subrr. +move=> neq_n_N1; rewrite -!(@mulNrNz x). +rewrite (_ : -n = -(n+1) + 1) 1:/# mulrS 1:/#. +by rewrite addrA subrr add0r. +qed. + +lemma mulrDz (x : g) (n m : int) : intmul x (n + m) = intmul x n + intmul x m. +proof. +wlog: n m / 0 <= m => [wlog|]. ++ case: (0 <= m) => [/wlog|]; first by apply. + rewrite -ltzNge => lt0_m; rewrite (_ : n + m = -(-m - n)) 1:/#. + by rewrite mulrNz addzC wlog 1:/# !mulrNz -opprD opprK. +elim: m => /= [|m ge0_m ih]; first by rewrite mulr0z addr0. +by rewrite addzA !mulrSz ih addrCA. +qed. + +end section. + +(* -------------------------------------------------------------------- *) +type class comring <: group = { + op oner : comring + op ( * ) : comring -> comring -> comring + op invr : comring -> comring + op unit : comring -> bool + + axiom oner_neq0 : oner <> zeror + axiom mulrA : associative ( * ) + axiom mulrC : commutative ( * ) + axiom mul1r : left_id oner ( * ) + axiom mulrDl : left_distributive ( * ) (+)<:comring> + axiom mulVr : left_inverse_in unit oner invr ( * ) + axiom unitP : forall (x y : comring), y * x = oner => unit x + axiom unitout : forall (x : comring), !unit x => invr x = x +}. + +section. +declare type r <: comring. + +instance monoid with r + op idm = oner<:r> + op (+) = ( * )<:r>. +realize addmA by exact: mulrA. +realize addmC by exact: mulrC. +realize add0m by exact: mul1r. + +abbrev ( / ) (x y : r) = x * (invr y). + +lemma mulr1: right_id oner ( * )<:r>. +proof. by move=> x; rewrite mulrC mul1r. qed. + +lemma mulrCA: left_commutative ( * )<:r>. +proof. by move=> x y z; rewrite !mulrA (@mulrC x y). qed. + +lemma mulrAC: right_commutative ( * )<:r>. +proof. by move=> x y z; rewrite -!mulrA (@mulrC y z). qed. + +lemma mulrACA: interchange ( * ) ( * )<:r>. +proof. by move=> x y z t; rewrite -!mulrA (mulrCA y). qed. + +lemma mulrSl (x y : r) : (x + oner) * y = x * y + y. +proof. by rewrite mulrDl mul1r. qed. + +lemma mulrDr: right_distributive ( * ) (+)<:r>. +proof. by move=> x y z; rewrite mulrC mulrDl !(@mulrC _ x). qed. + +lemma mul0r: left_zero zeror ( * )<:r>. +proof. by move=> x; apply: (@addIr (oner * x)); rewrite -mulrDl !add0r mul1r. qed. + +lemma mulr0: right_zero zeror ( * )<:r>. +proof. by move=> x; apply: (@addIr (x * oner)); rewrite -mulrDr !add0r mulr1. qed. + +lemma mulrN (x y : r): x * (- y) = - (x * y). +proof. by apply: (@addrI (x * y)); rewrite -mulrDr !addrN mulr0. qed. + +lemma mulNr (x y : r): (- x) * y = - (x * y). +proof. by apply: (@addrI (x * y)); rewrite -mulrDl !addrN mul0r. qed. + +lemma mulrNN (x y : r): (- x) * (- y) = x * y. +proof. by rewrite mulrN mulNr opprK. qed. + +lemma mulN1r (x : r): (-oner) * x = -x. +proof. by rewrite mulNr mul1r. qed. + +lemma mulrN1 (x : r): x * -oner = -x. +proof. by rewrite mulrN mulr1. qed. + +lemma mulrBl: left_distributive ( * ) (-)<:r>. +proof. by move=> x y z; rewrite mulrDl !mulNr. qed. + +lemma mulrBr: right_distributive ( * ) (-)<:r>. +proof. by move=> x y z; rewrite mulrDr !mulrN. qed. + +lemma mulrnAl (x y : r) n : 0 <= n => (intmul x n) * y = intmul (x * y) n. +proof. +elim: n => [|n ge0n ih]; rewrite !(mulr0z, mulrS) ?mul0r //. +by rewrite mulrDl ih. +qed. + +lemma mulrnAr (x y : r) n : 0 <= n => x * (intmul y n) = intmul (x * y) n. +proof. +elim: n => [|n ge0n ih]; rewrite !(mulr0z, mulrS) ?mulr0 //. +by rewrite mulrDr ih. +qed. + +lemma mulrzAl (x y : r) z : (intmul x z) * y = intmul (x * y) z. +proof. +case: (lezWP 0 z)=> [|_] le; first by rewrite mulrnAl. +by rewrite -oppzK mulrNz mulNr mulrnAl -?mulrNz // oppz_ge0. +qed. + +lemma mulrzAr x (y : r) z : x * (intmul y z) = intmul (x * y) z. +proof. +case: (lezWP 0 z)=> [|_] le; first by rewrite mulrnAr. +by rewrite -oppzK mulrNz mulrN mulrnAr -?mulrNz // oppz_ge0. +qed. + +lemma mulrV: right_inverse_in unit oner invr ( * )<:r>. +proof. by move=> x /mulVr; rewrite mulrC. qed. + +lemma divrr (x : r): unit x => x / x = oner. +proof. by apply/mulrV. qed. + +lemma invr_out (x : r): !unit x => invr x = x. +proof. by apply/unitout. qed. + +lemma unitrP (x : r): unit x <=> (exists y, y * x = oner). +proof. by split=> [/mulVr<- |]; [exists (invr x) | case=> y /unitP]. qed. + +lemma mulKr: left_loop_in unit invr ( * )<:r>. +proof. by move=> x un_x y; rewrite mulrA mulVr // mul1r. qed. + +lemma mulrK: right_loop_in unit invr ( * )<:r>. +proof. by move=> y un_y x; rewrite -mulrA mulrV // mulr1. qed. + +lemma mulVKr: rev_left_loop_in unit invr ( * )<:r>. +proof. by move=> x un_x y; rewrite mulrA mulrV // mul1r. qed. + +lemma mulrVK: rev_right_loop_in unit invr ( * )<:r>. +proof. by move=> y nz_y x; rewrite -mulrA mulVr // mulr1. qed. + +lemma mulrI: right_injective_in unit ( * )<:r>. +proof. by move=> x Ux; have /can_inj h := mulKr _ Ux. qed. + +lemma mulIr: left_injective_in unit ( * )<:r>. +proof. by move=> x /mulrI h y1 y2; rewrite !(@mulrC _ x) => /h. qed. + +lemma unitrE (x : r): unit x <=> (x / x = oner). +proof. +split=> [Ux|xx1]; 1: by apply/divrr. +by apply/unitrP; exists (invr x); rewrite mulrC. +qed. + +lemma invrK: involutive invr<:r>. +proof. +move=> x; case: (unit x)=> Ux; 2: by rewrite !invr_out. +rewrite -(mulrK _ Ux (invr (invr x))) -mulrA. +rewrite (@mulrC x) mulKr //; apply/unitrP. +by exists x; rewrite mulrV. +qed. + +lemma invr_inj: injective invr<:r>. +proof. by apply: (can_inj _ _ invrK). qed. + +lemma unitrV (x : r): unit (invr x) <=> unit x. +proof. by rewrite !unitrE invrK mulrC. qed. + +lemma unitr1: unit oner<:r>. +proof. by apply/unitrP; exists oner; rewrite mulr1. qed. + +lemma invr1: invr oner = oner<:r>. +proof. by rewrite -{2}(mulVr _ unitr1) mulr1. qed. + +lemma div1r x: oner / x = invr x. +proof. by rewrite mul1r. qed. + +lemma divr1 x: x / oner = x. +proof. by rewrite invr1 mulr1. qed. + +lemma unitr0: !unit zeror<:r>. +proof. by apply/negP=> /unitrP [y]; rewrite mulr0 eq_sym oner_neq0. qed. + +lemma invr0: invr zeror = zeror<:r>. +proof. by rewrite invr_out ?unitr0. qed. + +lemma unitrN1: unit (-oner<:r>). +proof. by apply/unitrP; exists (-oner); rewrite mulrNN mulr1. qed. + +lemma invrN1: invr (-oner) = -oner<:r>. +proof. by rewrite -{2}(divrr unitrN1) mulN1r opprK. qed. + +lemma unitrMl (x y : r) : unit y => (unit (x * y) <=> unit x). +proof. (* FIXME: wlog *) +move=> uy; case: (unit x)=> /=; last first. + apply/contra=> uxy; apply/unitrP; exists (y * invr (x * y)). + apply/(mulrI (invr y)); first by rewrite unitrV. + rewrite !mulrA mulVr // mul1r; apply/(mulIr y)=> //. + by rewrite -mulrA mulVr // mulr1 mulVr. +move=> ux; apply/unitrP; exists (invr y * invr x). +by rewrite -!mulrA mulKr // mulVr. +qed. + +lemma unitrMr (x y : r): unit x => (unit (x * y) <=> unit y). +proof. +move=> ux; split=> [uxy|uy]; last by rewrite unitrMl. +by rewrite -(mulKr _ ux y) unitrMl ?unitrV. +qed. + +lemma unitrM (x y : r) : unit (x * y) <=> (unit x /\ unit y). +proof. +case: (unit x) => /=; first by apply: unitrMr. +apply: contra => /unitrP[z] zVE; apply/unitrP. +by exists (y * z); rewrite mulrAC (@mulrC y) (@mulrC _ z). +qed. + +lemma unitrN (x : r) : unit (-x) <=> unit x. +proof. by rewrite -mulN1r unitrMr // unitrN1. qed. + +lemma invrM (x y : r) : unit x => unit y => invr (x * y) = invr y * invr x. +proof. +move=> Ux Uy; have Uxy: unit (x * y) by rewrite unitrMl. +by apply: (mulrI _ Uxy); rewrite mulrV ?mulrA ?mulrK ?mulrV. +qed. + +lemma invrN (x : r) : invr (- x) = - (invr x). +proof. +case: (unit x) => ux; last by rewrite !invr_out ?unitrN. +by rewrite -mulN1r invrM ?unitrN1 // invrN1 mulrN1. +qed. + +lemma invr_neq0 (x : r) : x <> zeror => invr x <> zeror. +proof. +move=> nx0; case: (unit x)=> Ux; last by rewrite invr_out ?Ux. +by apply/negP=> x'0; move: Ux; rewrite -unitrV x'0 unitr0. +qed. + +lemma invr_eq0 (x : r) : (invr x = zeror) <=> (x = zeror). +proof. by apply/iff_negb; split=> /invr_neq0; rewrite ?invrK. qed. + +lemma invr_eq1 (x : r) : (invr x = oner) <=> (x = oner). +proof. by rewrite (inv_eq invrK) invr1. qed. + +op ofint n = intmul oner<:r> n. + +lemma ofint0: ofint 0 = zeror. +proof. by apply/mulr0z. qed. + +lemma ofint1: ofint 1 = oner. +proof. by apply/mulr1z. qed. + +lemma ofintS (i : int): 0 <= i => ofint (i+1) = oner + ofint i. +proof. by apply/mulrS. qed. + +lemma ofintN (i : int): ofint (-i) = - (ofint i). +proof. by apply/mulrNz. qed. + +lemma mul1r0z x: x * ofint 0 = zeror. +proof. by rewrite ofint0 mulr0. qed. + +lemma mul1r1z x : x * ofint 1 = x. +proof. by rewrite ofint1 mulr1. qed. + +lemma mul1r2z x : x * ofint 2 = x + x. +proof. by rewrite /ofint mulr2z mulrDr mulr1. qed. + +lemma mulr_intl x z : (ofint z) * x = intmul x z. +proof. by rewrite mulrzAl mul1r. qed. + +lemma mulr_intr x z : x * (ofint z) = intmul x z. +proof. by rewrite mulrzAr mulr1. qed. + +op exp (x : r) (n : int) = + if n < 0 + then invr (iterop (-n) ( * ) x oner) + else iterop n ( * ) x oner. + +lemma expr0 x: exp x 0 = oner. +proof. by rewrite /exp /= iterop0. qed. + +lemma expr1 x: exp x 1 = x. +proof. by rewrite /exp /= iterop1. qed. + +lemma exprS (x : r) i: 0 <= i => exp x (i+1) = x * (exp x i). +proof. +move=> ge0i; rewrite /exp !ltzNge ge0i addz_ge0 //=. +(* we want to use the multiplicative monoid instance here *) +(* by rewrite !Monoid.iteropE iterS. *) admit. +qed. + +lemma expr_pred (x : r) i : 0 < i => exp x i = x * (exp x (i - 1)). +proof. smt(exprS). qed. + +lemma exprSr (x : r) i: 0 <= i => exp x (i+1) = (exp x i) * x. +proof. by move=> ge0_i; rewrite exprS // mulrC. qed. + +lemma expr2 x: exp x 2 = x * x. +proof. by rewrite (@exprS _ 1) // expr1. qed. + +lemma exprN (x : r) (i : int): exp x (-i) = invr (exp x i). +proof. +case: (i = 0) => [->|]; first by rewrite oppz0 expr0 invr1. +rewrite /exp oppz_lt0 ltzNge lez_eqVlt oppzK=> -> /=. +by case: (_ < _)%Int => //=; rewrite invrK. +qed. + +lemma exprN1 (x : r) : exp x (-1) = invr x. +proof. by rewrite exprN expr1. qed. + +lemma unitrX x m : unit x => unit (exp x m). +proof. +move=> invx; wlog: m / (0 <= m) => [wlog|]. ++ (have [] : (0 <= m \/ 0 <= -m) by move=> /#); first by apply: wlog. + by move=> ?; rewrite -oppzK exprN unitrV &(wlog). +elim: m => [|m ge0_m ih]; first by rewrite expr0 unitr1. +by rewrite exprS // &(unitrMl). +qed. + +lemma unitrX_neq0 x m : m <> 0 => unit (exp x m) => unit x. +proof. +wlog: m / (0 < m) => [wlog|]. ++ case: (0 < m); [by apply: wlog | rewrite ltzNge /= => le0_m nz_m]. + by move=> h; (apply: (wlog (-m)); 1,2:smt()); rewrite exprN unitrV. +by move=> gt0_m _; rewrite (_ : m = m - 1 + 1) // exprS 1:/# unitrM. +qed. + +lemma exprV (x : r) (i : int): exp (invr x) i = exp x (-i). +proof. +wlog: i / (0 <= i) => [wlog|]; first by smt(exprN). +elim: i => /= [|i ge0_i ih]; first by rewrite !expr0. +case: (i = 0) => [->|] /=; first by rewrite exprN1 expr1. +move=> nz_i; rewrite exprS // ih !exprN. +case: (unit x) => [invx|invNx]. ++ by rewrite -invrM ?unitrX // exprS // mulrC. +rewrite !invr_out //; last by rewrite exprS. ++ by apply: contra invNx; apply: unitrX_neq0 => /#. ++ by apply: contra invNx; apply: unitrX_neq0 => /#. +qed. + +lemma exprVn (x : r) (n : int) : 0 <= n => exp (invr x) n = invr (exp x n). +proof. +elim: n => [|n ge0_n ih]; first by rewrite !expr0 invr1. +case: (unit x) => ux. +- by rewrite exprSr -1:exprS // invrM ?unitrX // ih -invrM // unitrX. +- by rewrite !invr_out //; apply: contra ux; apply: unitrX_neq0 => /#. +qed. + +lemma exprMn (x y : r) (n : int) : 0 <= n => exp (x * y) n = exp x n * exp y n. +proof. +elim: n => [|n ge0_n ih]; first by rewrite !expr0 mulr1. +by rewrite !exprS // mulrACA ih. +qed. + +lemma exprD_nneg x (m n : int) : 0 <= m => 0 <= n => + exp x (m + n) = exp x m * exp x n. +proof. + move=> ge0_m ge0_n; elim: m ge0_m => [|m ge0_m ih]. + by rewrite expr0 mul1r. + by rewrite addzAC !exprS ?addz_ge0 // ih mulrA. +qed. + +lemma exprD x (m n : int) : unit x => exp x (m + n) = exp x m * exp x n. +proof. +wlog: m n x / (0 <= m + n) => [wlog invx|]. ++ case: (0 <= m + n); [by move=> ?; apply: wlog | rewrite lezNgt /=]. + move=> lt0_mDn; rewrite -(@oppzK (m + n)) -exprV. + rewrite -{2}(@oppzK m) -{2}(@oppzK n) -!(@exprV _ (- _)%Int). + by rewrite -wlog 1:/# ?unitrV //#. +move=> ge0_mDn invx; wlog: m n ge0_mDn / (m <= n) => [wlog|le_mn]. ++ by case: (m <= n); [apply: wlog | rewrite mulrC addzC /#]. +(have ge0_n: 0 <= n by move=> /#); elim: n ge0_n m le_mn ge0_mDn. ++ by move=> n _ _ /=; rewrite expr0 mulr1. +move=> n ge0_n ih m le_m_Sn ge0_mDSn; move: ge0_mDSn. +rewrite lez_eqVlt => -[?|]; first have->: n+1 = -m by move=> /#. ++ by rewrite subzz exprN expr0 divrr // unitrX. +move=> gt0_mDSn; move: le_m_Sn; rewrite lez_eqVlt. +case=> [->>|lt_m_Sn]; first by rewrite exprD_nneg //#. +by rewrite addzA exprS 1:/# ih 1,2:/# exprS // mulrCA. +qed. + +lemma exprM x (m n : int) : + exp x (m * n) = exp (exp x m) n. +proof. +wlog : n / 0 <= n. ++ move=> h; case: (0 <= n) => hn; 1: by apply h. + by rewrite -{1}(@oppzK n) (_: m * - -n = -(m * -n)) 1:/# + exprN h 1:/# exprN invrK. +wlog : m / 0 <= m. ++ move=> h; case: (0 <= m) => hm hn; 1: by apply h. + rewrite -{1}(@oppzK m) (_: (- -m) * n = - (-m) * n) 1:/#. + by rewrite exprN h 1:/# // exprN exprV exprN invrK. +elim/natind: n => [|n hn ih hm _]; 1: smt (expr0). +by rewrite mulzDr exprS //= mulrC exprD_nneg 1:/# 1:// ih. +qed. + +lemma expr0n n : 0 <= n => exp zeror n = if n = 0 then oner else zeror. +proof. +elim: n => [|n ge0_n _]; first by rewrite expr0. +by rewrite exprS // mul0r addz1_neq0. +qed. + +lemma expr0z z : exp zeror z = if z = 0 then oner else zeror. +proof. +case: (0 <= z) => [/expr0n // | /ltzNge lt0_z]. +rewrite -{1}(@oppzK z) exprN; have ->/=: z <> 0 by smt(). +rewrite invr_eq0 expr0n ?oppz_ge0 1:ltzW //. +by have ->/=: -z <> 0 by smt(). +qed. + +lemma expr1z z : exp oner z = oner. +proof. +elim/intwlog: z. ++ by move=> n h; rewrite -(@oppzK n) exprN h invr1. ++ by rewrite expr0. ++ by move=> n ge0_n ih; rewrite exprS // mul1r ih. +qed. + +lemma sqrrD (x y : r) : + exp (x + y) 2 = exp x 2 + intmul (x * y) 2 + exp y 2. +proof. +by rewrite !expr2 mulrDl !mulrDr mulr2z !addrA (@mulrC y x). +qed. + +lemma sqrrN x : exp (-x) 2 = exp x 2. +proof. by rewrite !expr2 mulrNN. qed. + +lemma sqrrB x y : + exp (x - y) 2 = exp x 2 - intmul (x * y) 2 + exp y 2. +proof. by rewrite sqrrD sqrrN mulrN mulNrz. qed. + +lemma signr_odd n : 0 <= n => exp (-oner) (b2i (odd n)) = exp (-oner) n. +proof. +elim: n => [|n ge0_nih]; first by rewrite odd0 expr0 expr0. +rewrite !(iterS, oddS) // exprS // -/(odd _) => <-. +by case: (odd _); rewrite /b2i /= !(expr0, expr1) mulN1r ?opprK. +qed. + +lemma subr_sqr_1 x : exp x 2 - oner = (x - oner) * (x + oner). +proof. +rewrite mulrBl mulrDr !(mulr1, mul1r) expr2 -addrA. +by congr; rewrite opprD addrA addrN add0r. +qed. + +op lreg (x : r) = injective (fun y => x * y). + +lemma mulrI_eq0 x y : lreg x => (x * y = zeror) <=> (y = zeror). +proof. by move=> reg_x; rewrite -{1}(mulr0 x) (inj_eq reg_x). qed. + +lemma lreg_neq0 x : lreg x => x <> zeror. +proof. +apply/contraL=> ->; apply/negP => /(_ zeror oner). +by rewrite (@eq_sym _ oner) oner_neq0 /= !mul0r. +qed. + +lemma mulrI0_lreg x : (forall y, x * y = zeror => y = zeror) => lreg x. +proof. +by move=> reg_x y z eq; rewrite -subr_eq0 &(reg_x) mulrBr eq subrr. +qed. + +lemma lregN x : lreg x => lreg (-x). +proof. by move=> reg_x y z; rewrite !mulNr => /oppr_inj /reg_x. qed. + +lemma lreg1 : lreg oner. +proof. by move=> x y; rewrite !mul1r. qed. + +lemma lregM x y : lreg x => lreg y => lreg (x * y). +proof. by move=> reg_x reg_y z t; rewrite -!mulrA => /reg_x /reg_y. qed. + +lemma lregXn x n : 0 <= n => lreg x => lreg (exp x n). +proof. +move=> + reg_x; elim: n => [|n ge0_n ih]. +- by rewrite expr0 &(lreg1). +- by rewrite exprS // &(lregM). +qed. +end section. + +(* +(* -------------------------------------------------------------------- *) +abstract theory ComRingDflInv. + clone include ComRing with + pred unit (x : t) = exists y, y * x = oner, + op invr (x : t) = choiceb (fun y => y * x = oner) x + + proof mulVr, unitP, unitout. + + realize mulVr. + proof. + move=> x ^ un_x [y ^ -> <-] @/invr_. + by have /= -> := choicebP _ x un_x. + qed. + + realize unitP. + proof. by move=> x y eq; exists y. qed. + + realize unitout. + proof. + by move=> x; rewrite /unit_ negb_exists => /choiceb_dfl /(_ x). + qed. +end ComRingDflInv. +*) + +(* -------------------------------------------------------------------- *) +type class boolring <: comring = { + axiom mulrr : forall (x : boolring), x * x = x +}. + +lemma addrr ['a <: boolring] (x : 'a): x + x = zeror. +proof. +apply (@addrI (x + x)); rewrite addr0 -{1 2 3 4}[x]mulrr. +by rewrite -mulrDr -mulrDl mulrr. +qed. + +(* -------------------------------------------------------------------- *) +type class idomain <: comring = { + axiom mulf_eq0: + forall (x y : idomain), x * y = zeror <=> x = zeror \/ y = zeror +}. + +section. +declare type r <: idomain. + +lemma mulf_neq0 (x y : r): x <> zeror => y <> zeror => x * y <> zeror. +proof. by move=> nz_x nz_y; apply/negP; rewrite mulf_eq0 /#. qed. + +lemma expf_eq0 (x : r) n : (exp x n = zeror) <=> (n <> 0 /\ x = zeror). +proof. +elim/intwlog: n => [n| |n ge0_n ih]. ++ by rewrite exprN invr_eq0 /#. ++ by rewrite expr0 oner_neq0. +by rewrite exprS // mulf_eq0 ih addz1_neq0 ?andKb. +qed. + +lemma mulfI (x : r): x <> zeror => injective (( * ) x). +proof. +move=> ne0_x y y'; rewrite -(opprK (x * y')) -mulrN -addr_eq0. +by rewrite -mulrDr mulf_eq0 ne0_x /= addr_eq0 opprK. +qed. + +lemma mulIf (x : r): x <> zeror => injective (fun y => y * x). +proof. by move=> nz_x y z; rewrite -!(@mulrC x); exact: mulfI. qed. + +lemma sqrf_eq1 (x : r): (exp x 2 = oner) <=> (x = oner \/ x = -oner). +proof. by rewrite -subr_eq0 subr_sqr_1 mulf_eq0 subr_eq0 addr_eq0. qed. + +lemma lregP (x : r): lreg x <=> x <> zeror. +proof. by split=> [/lreg_neq0//|/mulfI]. qed. + +lemma eqr_div (x1 y1 x2 y2 : r) : unit y1 => unit y2 => + (x1 / y1 = x2 / y2) <=> (x1 * y2 = x2 * y1). +proof. +move=> Nut1 Nut2; rewrite -{1}(@mulrK y2 _ x1) //. +rewrite -{1}(@mulrK y1 _ x2) // -!mulrA (@mulrC (invr y1)) !mulrA. +split=> [|->] //; + (have nz_Vy1: unit (invr y1) by rewrite unitrV); + (have nz_Vy2: unit (invr y2) by rewrite unitrV). +by move/(mulIr _ nz_Vy1)/(mulIr _ nz_Vy2). +qed. +end section. + +(* -------------------------------------------------------------------- *) +(* +(* TODO: Disjointness of type class operator names? *) +type class ffield <: group = { + op onef : ffield + op ( * ) : ffield -> ffield -> ffield + op invf : ffield -> ffield + + axiom onef_neq0 : onef <> zeror + axiom mulfA : associative ( * ) + axiom mulfC : commutative ( * ) + axiom mul1f : left_id onef ( * ) + axiom mulfDl : left_distributive ( * ) (+)<:ffield> + axiom mulVf : left_inverse_in (predC (pred1 zeror)) onef invf ( * ) + axiom unitP : forall (x y : ffield), y * x = onef => x <> zeror + axiom unitout : invr zeror = zeror +}. +*) + +(* TODO: Probably not the right way *) +type class ffield <: idomain = { + axiom unit_neq0: forall (x : ffield), unit x <=> x <> zeror +}. + +section. +declare type f <: ffield. + +lemma mulfV (x : f): x <> zeror => x * (invr x) = oner. +proof. by move=> /unit_neq0/mulrV. qed. + +lemma mulVf (x : f): x <> zeror => (invr x) * x = oner. +proof. by move=> /unit_neq0/mulVr. qed. + +lemma divff (x : f): x <> zeror => x / x = oner. +proof. by move=> /unit_neq0/divrr. qed. + +lemma invfM (x y : f) : invr (x * y) = invr x * invr y. +proof. +case: (x = zeror) => [->|nz_x]; first by rewrite !(mul0r, invr0). +case: (y = zeror) => [->|nz_y]; first by rewrite !(mulr0, invr0). +by rewrite invrM ?unit_neq0 // mulrC. +qed. + +lemma invf_div (x y : f) : invr (x / y) = y / x. +proof. by rewrite invfM invrK mulrC. qed. + +lemma eqf_div (x1 y1 x2 y2 : f) : y1 <> zeror => y2 <> zeror => + (x1 / y1 = x2 / y2) <=> (x1 * y2 = x2 * y1). +proof. by rewrite -!unit_neq0; exact: eqr_div<:f>. qed. + +lemma expfM (x y : f) n : exp (x * y) n = exp x n * exp y n. +proof. +elim/intwlog: n => [n h | | n ge0_n ih]. ++ by rewrite -(@oppzK n) !(@exprN _ (-n)) h invfM. ++ by rewrite !expr0 mulr1. ++ by rewrite !exprS // mulrCA -!mulrA -ih mulrCA. +qed. +end section. + +(* --------------------------------------------------------------------- *) +(* Rewrite database for algebra tactic *) + +hint rewrite rw_algebra : . +hint rewrite inj_algebra : . + +(* -------------------------------------------------------------------- *) +(* TODO: Instantiation of type classes with inheritance is broken *) +(* TODO: Instantiation of type class operators with literals is broken *) +op zeroz = 0. +op addz (x y : int) = x + y. +op negz (x : int) = -x. + + +instance monoid with int + op idm = zeroz + op (+) = addz. +realize addmA by exact: addzA. +realize addmC by exact: addzC. +realize add0m by exact: add0z. + +(* TODO: This is just broken *) +instance group with int + (* op idm = zeroz *) + op [-] = negz. +realize addNr. +(* TODO: Note that the zero remains undefined *) +rewrite /left_inverse /negz /idm. +(* by exact: addNz. *) admit. + +(* +theory IntID. +clone include IDomain with + type t <- int, + pred unit (z : int) <- (z = 1 \/ z = -1), + op zeror <- 0, + op oner <- 1, + op ( + ) <- Int.( + ), + op [ - ] <- Int.([-]), + op ( * ) <- Int.( * ), + op invr <- (fun (z : int) => z) + proof * by smt + remove abbrev (-) + remove abbrev (/) + rename "ofint" as "ofint_id". + +abbrev (^) = exp. + +lemma intmulz z c : intmul z c = z * c. +proof. +have h: forall cp, 0 <= cp => intmul z cp = z * cp. + elim=> /= [|cp ge0_cp ih]; first by rewrite mulr0z. + by rewrite mulrS // ih mulrDr /= addrC. +smt(opprK mulrNz opprK). +qed. + +lemma poddX n x : 0 < n => odd (exp x n) = odd x. +proof. +rewrite ltz_def => - [] + ge0_n; elim: n ge0_n => // + + _ _. +elim=> [|n ge0_n ih]; first by rewrite expr1. +by rewrite exprS ?addz_ge0 // oddM ih andbb. +qed. + +lemma oddX n x : 0 <= n => odd (exp x n) = (odd x \/ n = 0). +proof. +rewrite lez_eqVlt; case: (n = 0) => [->// _|+ h]. ++ by rewrite expr0 odd1. ++ by case: h => [<-//|] /poddX ->. +qed. +end IntID. +*) diff --git a/examples/typeclasses/monoidtc.ec b/examples/typeclasses/monoidtc.ec new file mode 100644 index 0000000000..a892abbcb5 --- /dev/null +++ b/examples/typeclasses/monoidtc.ec @@ -0,0 +1,54 @@ +require import Int. + +(* -------------------------------------------------------------------- *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* -------------------------------------------------------------------- *) +lemma addm0 ['a <: addmonoid] : right_id idm (+)<:'a>. +proof. by move=> x; rewrite addmC add0m. qed. + +lemma addmCA ['a <: addmonoid] : left_commutative (+)<:'a>. +proof. by move=> x y z; rewrite !addmA (addmC x). qed. + +lemma addmAC ['a <: addmonoid] : right_commutative (+)<:'a>. +proof. by move=> x y z; rewrite -!addmA (addmC y). qed. + +lemma addmACA ['a <: addmonoid] : interchange (+)<:'a> (+)<:'a>. +proof. by move=> x y z t; rewrite -!addmA (addmCA y). qed. + +lemma iteropE ['a <: addmonoid] n x: iterop n (+)<:'a> x idm<:'a> = iter n ((+)<:'a> x) idm<:'a>. +proof. + elim/natcase n => [n le0_n|n ge0_n]. + + by rewrite ?(iter0, iterop0). + + by rewrite iterSr // addm0 iteropS. +qed. + +(* -------------------------------------------------------------------- *) +abstract theory AddMonoid. + type t. + + op idm : t. + op (+) : t -> t -> t. + + theory Axioms. + axiom addmA: associative AddMonoid.(+). + axiom addmC: commutative AddMonoid.(+). + axiom add0m: left_id AddMonoid.idm AddMonoid.(+). + end Axioms. + + instance addmonoid with t + op idm = idm + op (+) = (+). + + realize addmA by exact Axioms.addmA. + realize addmC by exact Axioms.addmC. + realize add0m by exact Axioms.add0m. + +end AddMonoid. diff --git a/examples/typeclasses/typeclass.ec b/examples/typeclasses/typeclass.ec new file mode 100644 index 0000000000..eaee3603cf --- /dev/null +++ b/examples/typeclasses/typeclass.ec @@ -0,0 +1,353 @@ +(* ==================================================================== *) +(* Typeclass examples *) + +(* -------------------------------------------------------------------- *) +require import AllCore List. + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +type class ['a] artificial = { + op myop : artificial * 'a +}. + +op myopi ['a] : int * 'a = (0, witness<:'a>). + +instance 'b artificial with ['b] int + op myop = myopi<:'b>. + +lemma reduce_tc : myop<:bool, int> = (0, witness). +proof. +class. +reflexivity. +qed. + +(* -------------------------------------------------------------------- *) +type class witness = { + op witness : witness +}. + +print witness. + +type class finite = { + op enum : finite list + axiom enumP : forall (x : finite), x \in enum +}. + +print enum. +print enumP. + +type class countable = { + op count : int -> countable + axiom countP : forall (x : countable), exists (n : int), x = count n +}. + +(* -------------------------------------------------------------------- *) +(* Simple algebraic structures *) + +type class magma = { + op mmul : magma -> magma -> magma +}. + +print mmul. + +type class semigroup <: magma = { + axiom mmulA : associative mmul<:semigroup> +}. + +print associative. + +type class monoid <: semigroup = { + op mid : monoid + + axiom mmulr0 : right_id mid mmul<:monoid> + axiom mmul0r : left_id mid mmul<:monoid> +}. + +type class group <: monoid = { + op minv : group -> group + + axiom mmulN : left_inverse mid minv mmul +}. + +type class ['a <: semigroup] semigroup_action = { + op amul : 'a -> semigroup_action -> semigroup_action + + axiom compatibility : + forall (g h : 'a) (x : semigroup_action), amul (mmul g h) x = amul g (amul h x) +}. + +type class ['a <: monoid] monoid_action <: 'a semigroup_action = { + axiom identity : forall (x : monoid_action), amul mid<:'a> x = x +}. + +(* TODO: why again is this not possible/a good idea? *) +(*type class finite_group <: group & finite = {}.*) + +(* -------------------------------------------------------------------- *) +(* Advanced algebraic structures *) + +type class comgroup = { + op zero : comgroup + op ([-]) : comgroup -> comgroup + op ( + ) : comgroup -> comgroup -> comgroup + + axiom addr0 : right_id zero ( + ) + axiom addrN : left_inverse zero ([-]) ( + ) + axiom addrC : commutative ( + ) + axiom addrA : associative ( + ) +}. + +type class comring <: comgroup = { + op one : comring + op ( * ) : comring -> comring -> comring + + axiom mulr1 : right_id one ( * ) + axiom mulrC : commutative ( * ) + axiom mulrA : associative ( * ) + axiom mulrDl : left_distributive ( * ) ( + ) +}. + +type class ['a <: comring] commodule <: comgroup = { + op ( ** ) : 'a -> commodule -> commodule + + axiom scalerDl : forall (a b : 'a) (x : commodule), + (a + b) ** x = (a ** x) + (b ** x) + axiom scalerDr : forall (a : 'a) (x y : commodule), + a ** (x + y) = (a ** x) + (a ** y) +}. + + +(* ==================================================================== *) +(* Abstract type examples *) + +(* TODO: finish the hierarchy here: + https://en.wikipedia.org/wiki/Magma_(algebra) *) +type foo <: witness. +type fingroup <: group & finite. + + + +(* TODO: printing typeclasses *) +print countable. +print magma. +print semigroup. +print monoid. +print group. +print semigroup_action. +print monoid_action. + + +(* ==================================================================== *) +(* Operator examples *) + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +op all_finite ['a <: finite] (p : 'a -> bool) = + all p enum<:'a>. + +op all_countable ['a <: countable] (p : 'a -> bool) = + forall (n : int), p (count<:'a> n). + +(* -------------------------------------------------------------------- *) +(* Simple algebraic structures *) + +(* TODO: weird issue and/or inapropriate error message : bug in ecUnify select_op*) + +print amul. +(* +op foo1 ['a <: semigroup, 'b <: 'a semigroup_action] = amul<:'a,'b>. +*) +op foo2 ['a <: semigroup, 'b <: 'a semigroup_action] (g : 'a) (x : 'b) = amul g x. +(* +op foo3 ['a <: semigroup, 'b <: 'a semigroup_action] (g : 'a) (x : 'b) = amul<:'a,'b> g x. +*) + +op big ['a, 'b <: monoid] (P : 'a -> bool) (F : 'a -> 'b) (r : 'a list) = + foldr mmul mid (map F (filter P r)). + + +(* ==================================================================== *) +(* Lemma examples *) + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +lemma all_finiteP ['a <: finite] p : (all_finite p) <=> (forall (x : 'a), p x). +proof. by rewrite/all_finite allP; split=> Hp x; rewrite Hp enumP. qed. + +lemma all_countableP ['a <: countable] p : (all_countable p) <=> (forall (x : 'a), p x). +proof. + rewrite/all_countable; split => [Hp x|Hp n]. + by case (countP x) => n ->>; rewrite Hp. + by rewrite Hp. +qed. + +lemma all_finite_countable ['a <: finite & countable] (p : 'a -> bool) : (all_finite p) <=> (all_countable p). +proof. by rewrite all_finiteP all_countableP. qed. + + +(* ==================================================================== *) +(* Instance examples *) + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +op bool_enum = [true; false]. + +(* TODO: we want to be able to give the list directly.*) +instance finite with bool + op enum = bool_enum. + +realize enumP. +proof. by case. qed. + +(* -------------------------------------------------------------------- *) +(* Advanced algebraic structures *) + +(* +op izero = 0. + +instance comgroup with int + op zero = izero + op ( + ) = CoreInt.add + op ([-]) = CoreInt.opp. + +(* TODO: might be any of the two addr0, also apply fails but rewrite works. + In ecScope, where instances are declared. *) +realize addr0 by rewrite addr0. +realize addrN by trivial. +realize addrC by rewrite addrC. +realize addrA by rewrite addrA. + +op foo = 1 + 3. + +print ( + ). +print foo. + +op ione = 1. + +(* TODO: this automatically fetches the only instance of comgroup we have defined for int. + We should give the choice of which instance to use, by adding as desired_name after the with. + Also we should give the choice to define directly an instance of comring with int. *) +instance comring with int + op one = ione + op ( * ) = CoreInt.mul. + +realize mulr1 by trivial. +realize mulrC by rewrite mulrC. +realize mulrA by rewrite mulrA. + +realize mulrDl. +proof. + (*TODO: in the goal, the typeclass operator + should have been replaced with the + from CoreInt, but has not been.*) + print mulrDl. + move => x y z. + class. + apply Ring.IntID.mulrDl. +qed. + +(* ==================================================================== *) +(* Misc *) + +(* -------------------------------------------------------------------- *) +(* TODO: which instance is kept in memory after this? *) + +op bool_enum_alt = [true; false]. + +instance finite with bool + op enum = bool_enum_alt. + +realize enumP. +proof. by case. qed. + +type class find_out <: finite = { + axiom rev_enum : rev<:find_out> enum = enum +}. + +instance find_out with bool. + +realize rev_enum. +proof. + admit. +qed. + + + +(* ==================================================================== *) +(* Old TODO list: 1-3 are done, modulo bugs, 4 is to be done, 5 will be done later. *) + +(* + 1. typage -> selection des operateurs / inference des instances de tc + 2. reduction + 3. unification (tactiques) + 4. clonage + 5. envoi au SMT + + 1. + Fop : + -(old) path * ty list -> form + -(new) path * (ty * (map tcname -> tcinstance)) list -> form + + op ['a <: monoid] (+) : 'a -> 'a -> 'a. + + (+)<:int + monoid -> intadd_monoid> + (+)<:int + monoid -> intmul_monoid> + + 1.1 module de construction des formules avec typage + 1.2 utiliser le module ci-dessous + + let module M = MkForm(struct let env = env' end) in + + 1.3 UnionFind avec contraintes de TC + + 1.4 Overloading: + 3 + 4 + a. 3 Int.(+) 4 + b. 3 Monoid<:int>.(+) 4 (-> instance du dessus -> ignore) + + 1.5 foo<: int[monoid -> intadd_monoid] > + foo<: int[monoid -> intmul_monoid] > + + 2. -> Monoid.(+)<:int> -> Int.(+) + + 3. -> Pb d'unification des op + (+)<: ?[monoid -> ?] > ~ Int.(+) + + Mecanisme de resolution des TC + + 4. -> il faut cloner les TC + + 5. + + a. encodage + + record 'a premonoid = { + op zero : 'a + op add : 'a -> 'a -> 'a; + } + + pred ['a] ismonoid (m : 'a premonoid) = { + left_id m.zero m.add + } + + op ['a <: monoid] foo (x y : 'a) = x + y + + ->> foo ['a] (m : 'a premonoid) (x y : 'a) = m.add x y + + lemma foo ['a <: monoid] P + + ->> foo ['a] (m : 'a premonoid) : ismonoid m => P + + let intmonoid = { zero = 0; add = intadd } + + lemma intmonoid_is_monoid : ismonoid int_monoid + + b. reduction avant envoi + (+)<: int[monoid -> intadd_monoid > -> Int.(+) + + c. ne pas envoyer certaines instances (e.g. int est un groupe) + -> instance [nosmt] e.g. +*) +*) diff --git a/src/ecAlgTactic.ml b/src/ecAlgTactic.ml index f926a7ff3b..faf5a01236 100644 --- a/src/ecAlgTactic.ml +++ b/src/ecAlgTactic.ml @@ -80,7 +80,7 @@ module Axioms = struct let addctt = fun subst x f -> EcSubst.add_opdef subst (xpath x) ([], f) in let subst = - EcSubst.add_tydef EcSubst.empty (xpath tname) ([], cr.r_type) in + EcSubst.add_tydef EcSubst.empty (xpath tname) ([], cr.r_type, []) in let subst = List.fold_left (fun subst (x, p) -> add subst x p) subst crcore in let subst = odfl subst (cr.r_opp |> omap (fun p -> add subst opp p)) in diff --git a/src/ecAst.ml b/src/ecAst.ml index dc04fe95e7..9b353472ce 100644 --- a/src/ecAst.ml +++ b/src/ecAst.ml @@ -3,7 +3,6 @@ open EcUtils open EcSymbols open EcIdent open EcPath -open EcUid module BI = EcBigInt @@ -33,7 +32,6 @@ type quantif = type hoarecmp = FHle | FHeq | FHge (* -------------------------------------------------------------------- *) - type 'a use_restr = { ur_pos : 'a option; (* If not None, can use only element in this set. *) ur_neg : 'a; (* Cannot use element in this set. *) @@ -42,6 +40,13 @@ type 'a use_restr = { type mr_xpaths = EcPath.Sx.t use_restr type mr_mpaths = EcPath.Sm.t use_restr +(* -------------------------------------------------------------------- *) +module TyUni = EcUid.CoreGen () +module TcUni = EcUid.CoreGen () + +type tyuni = TyUni.uid +type tcuni = TcUni.uid + (* -------------------------------------------------------------------- *) type ty = { ty_node : ty_node; @@ -51,12 +56,49 @@ type ty = { and ty_node = | Tglob of EcIdent.t (* The tuple of global variable of the module *) - | Tunivar of EcUid.uid + | Tunivar of tyuni | Tvar of EcIdent.t | Ttuple of ty list - | Tconstr of EcPath.path * ty list + | Tconstr of EcPath.path * etyarg list | Tfun of ty * ty +(* -------------------------------------------------------------------- *) +and etyarg = ty * tcwitness list + +and tcwitness = + (* Unification variable, possibly with a pending [lift] path to apply + once the variable is resolved. *) + | TCIUni of tcuni * int list + + | TCIConcrete of { + path: EcPath.path; + etyargs: (ty * tcwitness list) list; + (* Same semantics as [TCIAbstract.lift]. *) + lift: int list; + } + + | TCIAbstract of { + support: [ + | `Var of EcIdent.t + | `Abs of EcPath.path + ]; + offset: int; + (* Path through the parent DAG starting at the typeclass at + [support]'s [offset]-th position. [lift = []] means "use the + declared typeclass directly"; [lift = [i; j; ...]] means + "take parent index [i], then parent index [j] of that, ...". + For single-parent classes the path is always [0; 0; ...]. + For multi-parent (factory) classes, the path encodes which + parent edge is taken at each step. *) + lift: int list; + } + +(* -------------------------------------------------------------------- *) +and typeclass = { + tc_name : EcPath.path; + tc_args : etyarg list; +} + (* -------------------------------------------------------------------- *) and ovariable = { ov_name : EcSymbols.symbol option; @@ -84,7 +126,7 @@ and expr_node = | Eint of BI.zint (* int. literal *) | Elocal of EcIdent.t (* let-variables *) | Evar of prog_var (* module variable *) - | Eop of EcPath.path * ty list (* op apply to type args *) + | Eop of EcPath.path * etyarg list (* op apply to type args *) | Eapp of expr * expr list (* op. application *) | Equant of equantif * ebindings * expr (* fun/forall/exists *) | Elet of lpattern * expr * expr (* let binding *) @@ -185,7 +227,7 @@ and f_node = | Flocal of EcIdent.t | Fpvar of prog_var * memory | Fglob of EcIdent.t * memory - | Fop of EcPath.path * ty list + | Fop of EcPath.path * etyarg list | Fapp of form * form list | Ftuple of form list | Fproj of form * int @@ -781,6 +823,100 @@ let lp_fv = function (fun s (id, _) -> ofold Sid.add s id) Sid.empty ids +(* -------------------------------------------------------------------- *) +(* Append [extra] to a witness's [lift] path. Used during substitution + when a witness referencing the [k]-th tc of some support gets + replaced by the witness for that tc, which may itself need further + parent-walk steps. *) +let bump_lift (extra : int list) (tcw : tcwitness) : tcwitness = + if extra = [] then tcw else + match tcw with + | TCIUni (uid, l) -> TCIUni (uid, l @ extra) + | TCIConcrete c -> TCIConcrete { c with lift = c.lift @ extra } + | TCIAbstract a -> TCIAbstract { a with lift = a.lift @ extra } + +(* -------------------------------------------------------------------- *) +let rec tcw_fv (tcw : tcwitness) = + match tcw with + | TCIUni _ -> + Mid.empty + + | TCIConcrete { etyargs } -> + List.fold_left + (fun fv (ty, tcws) -> fv_union fv (fv_union ty.ty_fv (tcws_fv tcws))) + Mid.empty etyargs + + | TCIAbstract { support = `Var v } -> + Mid.singleton v 1 + + | TCIAbstract { support = `Abs _ } -> + Mid.empty + +and tcws_fv (tcws : tcwitness list) = + List.fold_left + (fun fv tcw -> fv_union fv (tcw_fv tcw)) + Mid.empty tcws + +let etyarg_fv ((ty, tcws) : etyarg) = + fv_union ty.ty_fv (tcws_fv tcws) + +let etyargs_fv (tyargs : etyarg list) = + List.fold_left + (fun fv tyarg -> fv_union fv (etyarg_fv tyarg)) + Mid.empty tyargs + +(* -------------------------------------------------------------------- *) +let rec tcw_equal (tcw1 : tcwitness) (tcw2 : tcwitness) = + match tcw1, tcw2 with + | TCIUni (uid1, l1), TCIUni (uid2, l2) -> + TcUni.uid_equal uid1 uid2 && l1 = l2 + + | TCIConcrete tcw1, TCIConcrete tcw2 -> + EcPath.p_equal tcw1.path tcw2.path + && tcw1.lift = tcw2.lift + && List.all2 etyarg_equal tcw1.etyargs tcw2.etyargs + + | TCIAbstract { support = support1; offset = o1; lift = l1 } + , TCIAbstract { support = support2; offset = o2; lift = l2 } + -> + let tyvar_eq () = + match support1, support2 with + | `Var x1, `Var x2 -> + EcIdent.id_equal x1 x2 + | `Abs p1, `Abs p2 -> + EcPath.p_equal p1 p2 + | _, _ -> false + + in o1 = o2 && l1 = l2 && tyvar_eq () + + | _, _ -> + false + +and etyarg_equal ((ty1, tcws1) : etyarg) ((ty2, tcws2) : etyarg) = + ty_equal ty1 ty2 && List.all2 tcw_equal tcws1 tcws2 + +(* -------------------------------------------------------------------- *) +let rec tcw_hash (tcw : tcwitness) = + let lift_hash = Why3.Hashcons.combine_list (fun i -> i) 0 in + match tcw with + | TCIUni (uid, l) -> + Why3.Hashcons.combine (Hashtbl.hash uid) (lift_hash l) + + | TCIConcrete tcw -> + Why3.Hashcons.combine_list + etyarg_hash + (Why3.Hashcons.combine (p_hash tcw.path) (lift_hash tcw.lift)) + tcw.etyargs + + | TCIAbstract { support = `Var tyvar; offset; lift } -> + Why3.Hashcons.combine2 (EcIdent.id_hash tyvar) offset (lift_hash lift) + + | TCIAbstract { support = `Abs p; offset; lift } -> + Why3.Hashcons.combine2 (EcPath.p_hash p) offset (lift_hash lift) + + and etyarg_hash ((ty, tcws) : etyarg) = + Why3.Hashcons.combine_list tcw_hash (ty_hash ty) tcws + (* -------------------------------------------------------------------- *) let e_equal = ((==) : expr -> expr -> bool) let e_hash = fun e -> e.e_tag @@ -791,7 +927,6 @@ let eqt_equal : equantif -> equantif -> bool = (==) let eqt_hash : equantif -> int = Hashtbl.hash (* -------------------------------------------------------------------- *) - let lv_equal lv1 lv2 = match lv1, lv2 with | LvVar (pv1, ty1), LvVar (pv2, ty2) -> @@ -815,7 +950,6 @@ let lv_fv = function let add s (pv, _) = EcIdent.fv_union s (pv_fv pv) in List.fold_left add Mid.empty pvs - let lv_hash = function | LvVar (pv, ty) -> Why3.Hashcons.combine (pv_hash pv) (ty_hash ty) @@ -825,7 +959,6 @@ let lv_hash = function (fun (pv, ty) -> Why3.Hashcons.combine (pv_hash pv) (ty_hash ty)) 0 pvs - (* -------------------------------------------------------------------- *) let i_equal = ((==) : instr -> instr -> bool) let i_hash = fun i -> i.i_tag @@ -835,7 +968,6 @@ let s_equal = ((==) : stmt -> stmt -> bool) let s_hash = fun s -> s.s_tag let s_fv = fun s -> s.s_fv - (*-------------------------------------------------------------------- *) let qt_equal : quantif -> quantif -> bool = (==) let qt_hash : quantif -> int = Hashtbl.hash @@ -1216,7 +1348,7 @@ module Hsty = Why3.Hashcons.Make (struct EcIdent.id_equal m1 m2 | Tunivar u1, Tunivar u2 -> - uid_equal u1 u2 + TyUni.uid_equal u1 u2 | Tvar v1, Tvar v2 -> id_equal v1 v2 @@ -1225,7 +1357,7 @@ module Hsty = Why3.Hashcons.Make (struct List.all2 ty_equal lt1 lt2 | Tconstr (p1, lt1), Tconstr (p2, lt2) -> - EcPath.p_equal p1 p2 && List.all2 ty_equal lt1 lt2 + EcPath.p_equal p1 p2 && List.all2 etyarg_equal lt1 lt2 | Tfun (d1, c1), Tfun (d2, c2)-> ty_equal d1 d2 && ty_equal c1 c2 @@ -1235,10 +1367,10 @@ module Hsty = Why3.Hashcons.Make (struct let hash ty = match ty.ty_node with | Tglob m -> EcIdent.id_hash m - | Tunivar u -> u + | Tunivar u -> Hashtbl.hash u | Tvar id -> EcIdent.tag id | Ttuple tl -> Why3.Hashcons.combine_list ty_hash 0 tl - | Tconstr (p, tl) -> Why3.Hashcons.combine_list ty_hash p.p_tag tl + | Tconstr (p, tl) -> Why3.Hashcons.combine_list etyarg_hash p.p_tag tl | Tfun (t1, t2) -> Why3.Hashcons.combine (ty_hash t1) (ty_hash t2) let fv ty = @@ -1250,7 +1382,7 @@ module Hsty = Why3.Hashcons.Make (struct | Tunivar _ -> Mid.empty | Tvar _ -> Mid.empty (* FIXME: section *) | Ttuple tys -> union (fun a -> a.ty_fv) tys - | Tconstr (_, tys) -> union (fun a -> a.ty_fv) tys + | Tconstr (_, tys) -> union etyarg_fv tys | Tfun (t1, t2) -> union (fun a -> a.ty_fv) [t1; t2] let tag n ty = { ty with ty_tag = n; ty_fv = fv ty.ty_node; } @@ -1260,7 +1392,6 @@ let mk_ty node = Hsty.hashcons { ty_node = node; ty_tag = -1; ty_fv = Mid.empty } (* ----------------------------------------------------------------- *) - module Hexpr = Why3.Hashcons.Make (struct type t = expr @@ -1277,7 +1408,7 @@ module Hexpr = Why3.Hashcons.Make (struct | Eop (p1, tys1), Eop (p2, tys2) -> (EcPath.p_equal p1 p2) - && (List.all2 ty_equal tys1 tys2) + && (List.all2 etyarg_equal tys1 tys2) | Eapp (e1, es1), Eapp (e2, es2) -> (e_equal e1 e2) @@ -1320,9 +1451,8 @@ module Hexpr = Why3.Hashcons.Make (struct | Elocal x -> Hashtbl.hash x | Evar x -> pv_hash x - | Eop (p, tys) -> - Why3.Hashcons.combine_list ty_hash - (EcPath.p_hash p) tys + | Eop (p, tyargs) -> + Why3.Hashcons.combine_list etyarg_hash (EcPath.p_hash p) tyargs | Eapp (e, es) -> Why3.Hashcons.combine_list e_hash (e_hash e) es @@ -1359,7 +1489,7 @@ module Hexpr = Why3.Hashcons.Make (struct match e with | Eint _ -> Mid.empty - | Eop (_, tys) -> union (fun a -> a.ty_fv) tys + | Eop (_, tyargs) -> etyargs_fv tyargs | Evar v -> pv_fv v | Elocal id -> fv_singleton id | Eapp (e, es) -> union e_fv (e :: es) @@ -1376,7 +1506,27 @@ module Hexpr = Why3.Hashcons.Make (struct end) (* -------------------------------------------------------------------- *) -let mk_expr e ty = +let normalize_enode (node : expr_node) : expr_node = + match node with + | Equant (_, [], body) -> + body.e_node + + | Equant (q1, bds1, { e_node = Equant (q2, bds2, body) }) + when q1 = q2 + -> Equant (q1, bds1 @ bds2, body) + + | Eapp (hd, []) -> + hd.e_node + + | Eapp ({ e_node = Eapp (hd, args1) }, args2) -> + Eapp (hd, args1 @ args2) + + | _ -> + node + +(* -------------------------------------------------------------------- *) +let mk_expr (e : expr_node) (ty : ty) = + let e = normalize_enode e in Hexpr.hashcons { e_node = e; e_tag = -1; e_fv = Mid.empty; e_ty = ty } (* -------------------------------------------------------------------- *) @@ -1411,7 +1561,7 @@ module Hsform = Why3.Hashcons.Make (struct EcIdent.id_equal mp1 mp2 && EcIdent.id_equal m1 m2 | Fop(p1,lty1), Fop(p2,lty2) -> - EcPath.p_equal p1 p2 && List.all2 ty_equal lty1 lty2 + EcPath.p_equal p1 p2 && List.all2 etyarg_equal lty1 lty2 | Fapp(f1,args1), Fapp(f2,args2) -> f_equal f1 f2 && List.all2 f_equal args1 args2 @@ -1465,8 +1615,10 @@ module Hsform = Why3.Hashcons.Make (struct | Fglob(mp, m) -> Why3.Hashcons.combine (EcIdent.id_hash mp) (EcIdent.id_hash m) - | Fop(p, lty) -> - Why3.Hashcons.combine_list ty_hash (EcPath.p_hash p) lty + | Fop(p, tyargs) -> + Why3.Hashcons.combine_list + etyarg_hash (EcPath.p_hash p) + tyargs | Fapp(f, args) -> Why3.Hashcons.combine_list f_hash (f_hash f) args @@ -1505,7 +1657,7 @@ module Hsform = Why3.Hashcons.Make (struct match f with | Fint _ -> Mid.empty - | Fop (_, tys) -> union (fun a -> a.ty_fv) tys + | Fop (_, tyargs) -> union etyarg_fv tyargs | Fpvar (PVglob pv,m) -> EcPath.x_fv (fv_add m Mid.empty) pv | Fpvar (PVloc _,m) -> fv_add m Mid.empty | Fglob (mp,m) -> fv_add mp (fv_add m Mid.empty) @@ -1581,7 +1733,28 @@ module Hsform = Why3.Hashcons.Make (struct { f with f_tag = n; f_fv = fv; } end) -let mk_form node ty = +(* -------------------------------------------------------------------- *) +let normalize_fnode (node : f_node) : f_node = + match node with + | Fquant (_, [], body) -> + body.f_node + + | Fquant (q1, bds1, { f_node = Fquant (q2, bds2, body) }) + when q1 = q2 + -> Fquant (q1, bds1 @ bds2, body) + + | Fapp (hd, []) -> + hd.f_node + + | Fapp ({ f_node = Fapp (hd, args1)}, args2) -> + Fapp (hd, args1 @ args2) + + | _ -> + node + +(* -------------------------------------------------------------------- *) +let mk_form (node : f_node) (ty : ty) = + let node = normalize_fnode (node) in let aout = Hsform.hashcons { f_node = node; diff --git a/src/ecAst.mli b/src/ecAst.mli index a13023aec3..96cd7fa6db 100644 --- a/src/ecAst.mli +++ b/src/ecAst.mli @@ -36,6 +36,13 @@ type mr_xpaths = EcPath.Sx.t use_restr type mr_mpaths = EcPath.Sm.t use_restr +(* -------------------------------------------------------------------- *) +module TyUni : EcUid.ICore with type uid = private EcUid.uid +module TcUni : EcUid.ICore with type uid = private EcUid.uid + +type tyuni = TyUni.uid +type tcuni = TcUni.uid + (* -------------------------------------------------------------------- *) type ty = private { ty_node : ty_node; @@ -45,12 +52,39 @@ type ty = private { and ty_node = | Tglob of EcIdent.t (* The tuple of global variable of the module *) - | Tunivar of EcUid.uid + | Tunivar of tyuni | Tvar of EcIdent.t | Ttuple of ty list - | Tconstr of EcPath.path * ty list + | Tconstr of EcPath.path * etyarg list | Tfun of ty * ty +(* -------------------------------------------------------------------- *) +and etyarg = ty * tcwitness list + +and tcwitness = + | TCIUni of tcuni * int list + + | TCIConcrete of { + path: EcPath.path; + etyargs: (ty * tcwitness list) list; + lift: int list; + } + + | TCIAbstract of { + support: [ + | `Var of EcIdent.t + | `Abs of EcPath.path + ]; + offset: int; + lift: int list; + } + +(* -------------------------------------------------------------------- *) +and typeclass = { + tc_name : EcPath.path; + tc_args : etyarg list; +} + (* -------------------------------------------------------------------- *) and ovariable = { ov_name : EcSymbols.symbol option; @@ -78,7 +112,7 @@ and expr_node = | Eint of BI.zint (* int. literal *) | Elocal of EcIdent.t (* let-variables *) | Evar of prog_var (* module variable *) - | Eop of EcPath.path * ty list (* op apply to type args *) + | Eop of EcPath.path * etyarg list (* op apply to type args *) | Eapp of expr * expr list (* op. application *) | Equant of equantif * ebindings * expr (* fun/forall/exists *) | Elet of lpattern * expr * expr (* let binding *) @@ -91,7 +125,6 @@ and ebinding = EcIdent.t * ty and ebindings = ebinding list (* -------------------------------------------------------------------- *) - and lvalue = | LvVar of (prog_var * ty) | LvTuple of (prog_var * ty) list @@ -179,7 +212,7 @@ and f_node = | Flocal of EcIdent.t | Fpvar of prog_var * memory | Fglob of EcIdent.t * memory - | Fop of EcPath.path * ty list + | Fop of EcPath.path * etyarg list | Fapp of form * form list | Ftuple of form list | Fproj of form * int @@ -489,6 +522,18 @@ val lp_equal : lpattern equality val lp_hash : lpattern hash val lp_fv : lpattern -> EcIdent.Sid.t +(* -------------------------------------------------------------------- *) +val etyarg_fv : etyarg -> int Mid.t +val etyargs_fv : etyarg list -> int Mid.t +val etyarg_hash : etyarg -> int +val etyarg_equal : etyarg -> etyarg -> bool + +(* -------------------------------------------------------------------- *) +val bump_lift : int list -> tcwitness -> tcwitness +val tcw_fv : tcwitness -> int Mid.t +val tcw_hash : tcwitness -> int +val tcw_equal : tcwitness -> tcwitness -> bool + (* -------------------------------------------------------------------- *) val e_equal : expr equality val e_hash : expr hash diff --git a/src/ecCallbyValue.ml b/src/ecCallbyValue.ml index 227ed19d11..c534c5f7df 100644 --- a/src/ecCallbyValue.ml +++ b/src/ecCallbyValue.ml @@ -217,7 +217,7 @@ and betared st s bd f args = (* -------------------------------------------------------------------- *) and try_reduce_record_projection - (st : state) ((p, _tys) : EcPath.path * ty list) (args : args) + (st : state) ((p, _tys) : EcPath.path * etyarg list) (args : args) = let exception Bailout in @@ -245,7 +245,7 @@ and try_reduce_record_projection (* -------------------------------------------------------------------- *) and try_reduce_fixdef - (st : state) ((p, tys) : EcPath.path * ty list) (args : args) + (st : state) ((p, tys) : EcPath.path * etyarg list) (args : args) = let exception Bailout in @@ -300,7 +300,9 @@ and try_reduce_fixdef let body = EcFol.form_of_expr body in let body = - Tvar.f_subst ~freshen:true op.EcDecl.op_tparams tys body in + Tvar.f_subst ~freshen:true + (List.combine (List.map fst op.EcDecl.op_tparams) tys) + body in Some (cbv st subst body (Args.create ty eargs)) diff --git a/src/ecCommands.ml b/src/ecCommands.ml index 2b08923ea9..ae73886d04 100644 --- a/src/ecCommands.ml +++ b/src/ecCommands.ml @@ -434,6 +434,13 @@ and process_subtype (scope : EcScope.scope) (subtype : psubtype located) = EcScope.notify scope `Info "added subtype: `%s'" (unloc subtype.pl_desc.pst_name); scope +(* -------------------------------------------------------------------- *) +and process_typeclass (scope : EcScope.scope) (tcd : ptypeclass located) = + EcScope.check_state `InTop "type class" scope; + let scope = EcScope.Ty.add_class scope tcd in + EcScope.notify scope `Info "added type class: `%s'" (unloc tcd.pl_desc.ptc_name); + scope + (* -------------------------------------------------------------------- *) and process_tycinst (scope : EcScope.scope) (tci : ptycinstance located) = EcScope.check_state `InTop "type class instance" scope; @@ -776,6 +783,7 @@ and process ?(src : string option) (ld : Loader.loader) (scope : EcScope.scope) match g.pl_desc with | Gtype t -> `Fct (fun scope -> process_types ?src scope (List.map (mk_loc loc) t)) | Gsubtype t -> `Fct (fun scope -> process_subtype scope (mk_loc loc t)) + | Gtypeclass t -> `Fct (fun scope -> process_typeclass scope (mk_loc loc t)) | Gtycinstance t -> `Fct (fun scope -> process_tycinst scope (mk_loc loc t)) | Gmodule m -> `Fct (fun scope -> process_module ?src scope m) | Ginterface i -> `Fct (fun scope -> process_interface ?src scope i) diff --git a/src/ecCoreEqTest.ml b/src/ecCoreEqTest.ml new file mode 100644 index 0000000000..9dc60f4059 --- /dev/null +++ b/src/ecCoreEqTest.ml @@ -0,0 +1,89 @@ +(* -------------------------------------------------------------------- + * Copyright (c) - 2012--2016 - IMDEA Software Institute + * Copyright (c) - 2012--2018 - Inria + * Copyright (c) - 2012--2018 - Ecole Polytechnique + * + * Distributed under the terms of the CeCILL-C-V1 license + * -------------------------------------------------------------------- *) + +(* -------------------------------------------------------------------- *) +open EcUtils +open EcTypes +open EcEnv + +(* -------------------------------------------------------------------- *) +type 'a eqtest = env -> 'a -> 'a -> bool + +(* -------------------------------------------------------------------- *) +let rec for_type env t1 t2 = + ty_equal t1 t2 || for_type_r env t1 t2 + +(* -------------------------------------------------------------------- *) +and for_type_r env t1 t2 = + match t1.ty_node, t2.ty_node with + | Tunivar uid1, Tunivar uid2 -> + EcAst.TyUni.uid_equal uid1 uid2 + + | Tvar i1, Tvar i2 -> i1 = i2 + + | Ttuple lt1, Ttuple lt2 -> + List.length lt1 = List.length lt2 + && List.all2 (for_type env) lt1 lt2 + + | Tfun (t1, t2), Tfun (t1', t2') -> + for_type env t1 t1' && for_type env t2 t2' + + | Tglob m1, Tglob m2 -> EcIdent.id_equal m1 m2 + + | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> + if + List.length lt1 = List.length lt2 + && List.all2 (for_etyarg env) lt1 lt2 + then true + else + if Ty.defined p1 env + then for_type env (Ty.unfold p1 lt1 env) (Ty.unfold p2 lt2 env) + else false + + | Tconstr (p1, lt1), _ when Ty.defined p1 env -> + for_type env (Ty.unfold p1 lt1 env) t2 + + | _, Tconstr (p2, lt2) when Ty.defined p2 env -> + for_type env t1 (Ty.unfold p2 lt2 env) + + | _, _ -> false + +(* -------------------------------------------------------------------- *) +and for_etyarg env ((ty1, tcws1) : etyarg) ((ty2, tcws2) : etyarg) = + for_type env ty1 ty2 && for_tcws env tcws1 tcws2 + +and for_etyargs env (tyargs1 : etyarg list) (tyargs2 : etyarg list) = + List.length tyargs1 = List.length tyargs2 + && List.for_all2 (for_etyarg env) tyargs1 tyargs2 + +and for_tcw env (tcw1 : tcwitness) (tcw2 : tcwitness) = + let tcw1 = EcTcCanonical.canonicalise_witness env tcw1 in + let tcw2 = EcTcCanonical.canonicalise_witness env tcw2 in + match tcw1, tcw2 with + | TCIUni (uid1, l1), TCIUni (uid2, l2) -> + EcAst.TcUni.uid_equal uid1 uid2 && l1 = l2 + + | TCIConcrete tcw1, TCIConcrete tcw2 -> + EcPath.p_equal tcw1.path tcw2.path + && tcw1.lift = tcw2.lift + && for_etyargs env tcw1.etyargs tcw2.etyargs + + | TCIAbstract { support = `Var v1; offset = o1; lift = l1 }, + TCIAbstract { support = `Var v2; offset = o2; lift = l2 } -> + EcIdent.id_equal v1 v2 && o1 = o2 && l1 = l2 + + | TCIAbstract { support = `Abs p1; offset = o1; lift = l1 }, + TCIAbstract { support = `Abs p2; offset = o2; lift = l2 } -> + EcPath.p_equal p1 p2 && o1 = o2 && l1 = l2 + + | _, _ -> + false + +and for_tcws env (tcws1 : tcwitness list) (tcws2 : tcwitness list) = + List.length tcws1 = List.length tcws2 + && List.for_all2 (for_tcw env) tcws1 tcws2 diff --git a/src/ecCoreEqTest.mli b/src/ecCoreEqTest.mli new file mode 100644 index 0000000000..aa6e5f705b --- /dev/null +++ b/src/ecCoreEqTest.mli @@ -0,0 +1,9 @@ +(* -------------------------------------------------------------------- *) +open EcTypes +open EcEnv + +(* -------------------------------------------------------------------- *) +type 'a eqtest = env -> 'a -> 'a -> bool + +val for_type : ty eqtest +val for_etyarg : etyarg eqtest diff --git a/src/ecCoreFol.ml b/src/ecCoreFol.ml index aae2bacbeb..092eff588c 100644 --- a/src/ecCoreFol.ml +++ b/src/ecCoreFol.ml @@ -153,7 +153,8 @@ let mk_form = EcAst.mk_form let f_node { f_node = form } = form (* -------------------------------------------------------------------- *) -let f_op x tys ty = mk_form (Fop (x, tys)) ty +let f_op x tys ty = mk_form (Fop (x, List.map (fun t -> (t, [])) tys)) ty +let f_op_tc x tys ty = mk_form (Fop (x, tys)) ty let f_app f args ty = let f, args' = @@ -463,9 +464,13 @@ let f_map gt g fp = (f_pvar id ty' s).inv | Fop (p, tys) -> - let tys' = List.Smart.map gt tys in + let tys' = + List.Smart.map + (fun ((t, w) as ety) -> + let t' = gt t in if t == t' then ety else (t', w)) + tys in let ty' = gt fp.f_ty in - f_op p tys' ty' + f_op_tc p tys' ty' | Fapp (f, fs) -> let f' = g f in @@ -956,7 +961,7 @@ let rec form_of_expr_r ?m (e : expr) = end | Eop (op, tys) -> - f_op op tys e.e_ty + f_op_tc op tys e.e_ty | Eapp (ef, es) -> f_app (form_of_expr_r ?m ef) (List.map (form_of_expr_r ?m) es) e.e_ty @@ -1001,7 +1006,7 @@ let expr_of_ss_inv f = | Fint z -> e_int z | Flocal x -> e_local x fp.f_ty - | Fop (p, tys) -> e_op p tys fp.f_ty + | Fop (p, tys) -> e_op_tc p tys fp.f_ty | Fapp (f, fs) -> e_app (aux f) (List.map aux fs) fp.f_ty | Ftuple fs -> e_tuple (List.map aux fs) | Fproj (f, i) -> e_proj (aux f) i fp.f_ty @@ -1043,7 +1048,7 @@ let expr_of_form f = | Fint z -> e_int z | Flocal x -> e_local x fp.f_ty - | Fop (p, tys) -> e_op p tys fp.f_ty + | Fop (p, tys) -> e_op_tc p tys fp.f_ty | Fapp (f, fs) -> e_app (aux f) (List.map aux fs) fp.f_ty | Ftuple fs -> e_tuple (List.map aux fs) | Fproj (f, i) -> e_proj (aux f) i fp.f_ty diff --git a/src/ecCoreFol.mli b/src/ecCoreFol.mli index b270d12d5e..2fd550cc37 100644 --- a/src/ecCoreFol.mli +++ b/src/ecCoreFol.mli @@ -12,9 +12,7 @@ open EcMemory type quantif = EcAst.quantif type hoarecmp = EcAst.hoarecmp - -type gty = EcAst.gty - +type gty = EcAst.gty type binding = (EcIdent.t * gty) type bindings = binding list @@ -74,7 +72,7 @@ val f_node : form -> f_node (* -------------------------------------------------------------------- *) (* not recursive *) -val f_map : (EcTypes.ty -> EcTypes.ty) -> (form -> form) -> form -> form +val f_map : (ty -> ty) -> (form -> form) -> form -> form val f_iter : (form -> unit) -> form -> unit val f_fold : ('a -> form -> 'a) -> 'a -> form -> 'a @@ -96,6 +94,7 @@ val f_glob : EcIdent.t -> memory -> ss_inv (* soft-constructors - common formulas constructors *) val f_op : path -> EcTypes.ty list -> EcTypes.ty -> form +val f_op_tc : path -> etyarg list -> EcTypes.ty -> form val f_app : form -> form list -> EcTypes.ty -> form val f_tuple : form list -> form val f_proj : form -> int -> EcTypes.ty -> form @@ -141,6 +140,7 @@ val f_equivF : ts_inv -> xpath -> xpath -> ts_inv -> form val f_equivS : memtype -> memtype -> ts_inv -> stmt -> stmt -> ts_inv -> form (* soft-constructors - eager *) +val f_eagerF_r : eagerF -> form val f_eagerF : ts_inv -> stmt -> xpath -> xpath -> stmt -> ts_inv -> form (* soft-constructors - Pr *) @@ -250,13 +250,13 @@ val destr_forall1 : form -> ident * gty * form val destr_exists1 : form -> ident * gty * form val destr_lambda1 : form -> ident * gty * form -val destr_op : form -> EcPath.path * ty list +val destr_op : form -> EcPath.path * etyarg list val destr_local : form -> EcIdent.t val destr_pvar : form -> prog_var * memory val destr_proj : form -> form * int val destr_tuple : form -> form list val destr_app : form -> form * form list -val destr_op_app : form -> (EcPath.path * ty list) * form list +val destr_op_app : form -> (EcPath.path * etyarg list) * form list val destr_not : form -> form val destr_nots : form -> bool * form val destr_and : form -> form * form diff --git a/src/ecCoreGoal.ml b/src/ecCoreGoal.ml index 4b75e5b0c1..873df1baa1 100644 --- a/src/ecCoreGoal.ml +++ b/src/ecCoreGoal.ml @@ -51,7 +51,7 @@ and pt_head = | PTCut of EcFol.form * cutsolve option | PTHandle of handle | PTLocal of EcIdent.t -| PTGlobal of EcPath.path * (ty list) +| PTGlobal of EcPath.path * etyarg list | PTTerm of proofterm and cutsolve = [`Done | `Smt | `DoneSmt] diff --git a/src/ecCoreGoal.mli b/src/ecCoreGoal.mli index f574b49bf3..7725546407 100644 --- a/src/ecCoreGoal.mli +++ b/src/ecCoreGoal.mli @@ -53,7 +53,7 @@ and pt_head = | PTCut of EcFol.form * cutsolve option | PTHandle of handle | PTLocal of EcIdent.t -| PTGlobal of EcPath.path * (ty list) +| PTGlobal of EcPath.path * etyarg list | PTTerm of proofterm and cutsolve = [`Done | `Smt | `DoneSmt] @@ -82,12 +82,12 @@ val pamemory : EcMemory.memory -> pt_arg val pamodule : EcPath.mpath * EcModules.module_sig -> pt_arg (* -------------------------------------------------------------------- *) -val paglobal : ?args:pt_arg list -> tys:ty list -> EcPath.path -> pt_arg +val paglobal : ?args:pt_arg list -> tys:etyarg list -> EcPath.path -> pt_arg val palocal : ?args:pt_arg list -> EcIdent.t -> pt_arg val pahandle : ?args:pt_arg list -> handle -> pt_arg (* -------------------------------------------------------------------- *) -val ptglobal : ?args:pt_arg list -> tys:ty list -> EcPath.path -> proofterm +val ptglobal : ?args:pt_arg list -> tys:etyarg list -> EcPath.path -> proofterm val ptlocal : ?args:pt_arg list -> EcIdent.t -> proofterm val pthandle : ?args:pt_arg list -> handle -> proofterm val ptcut : ?args:pt_arg list -> ?cutsolve:cutsolve -> EcFol.form -> proofterm diff --git a/src/ecCorePrinting.ml b/src/ecCorePrinting.ml index dc89869590..3eddc01ba1 100644 --- a/src/ecCorePrinting.ml +++ b/src/ecCorePrinting.ml @@ -4,8 +4,7 @@ module type PrinterAPI = sig open EcIdent open EcSymbols open EcPath - open EcTypes - open EcFol + open EcAst open EcDecl open EcModules open EcTheory @@ -59,7 +58,8 @@ module type PrinterAPI = sig val pp_mem : PPEnv.t -> EcIdent.t pp val pp_memtype : PPEnv.t -> EcMemory.memtype pp val pp_tyvar : PPEnv.t -> ident pp - val pp_tyunivar : PPEnv.t -> EcUid.uid pp + val pp_tyunivar : PPEnv.t -> tyuni pp + val pp_tcunivar : PPEnv.t -> tcuni pp val pp_path : path pp (* ------------------------------------------------------------------ *) @@ -86,11 +86,12 @@ module type PrinterAPI = sig | `Glob of EcIdent.t * EcMemory.memory | `PVar of EcTypes.prog_var * EcMemory.memory ] - + val pp_vsubst : PPEnv.t -> vsubst pp (* ------------------------------------------------------------------ *) val pp_typedecl : PPEnv.t -> (path * tydecl ) pp + val pp_typeclass : PPEnv.t -> (typeclass ) pp val pp_opdecl : ?long:bool -> PPEnv.t -> (path * operator ) pp val pp_added_op : PPEnv.t -> operator pp val pp_axiom : ?long:bool -> PPEnv.t -> (path * axiom ) pp diff --git a/src/ecCoreSubst.ml b/src/ecCoreSubst.ml index 14368120d7..1a4609c199 100644 --- a/src/ecCoreSubst.ml +++ b/src/ecCoreSubst.ml @@ -23,8 +23,17 @@ type sc_instantiate = { (* -------------------------------------------------------------------- *) type f_subst = { fs_freshen : bool; (* true means freshen locals *) - fs_u : ty Muid.t; + fs_u : ty TyUni.Muid.t; fs_v : ty Mid.t; + (* Witnesses to use when substituting [TCIAbstract `Var x] for a + type variable x that is being replaced by [fs_v]. The list is + indexed by witness offset. Empty list / missing key means: leave + the witness alone (caller is doing alpha-renaming, not + instantiation). *) + fs_tw : tcwitness list Mid.t; + (* Resolutions for TCIUni witnesses (typically extracted from the + unifier's tcenv.resolution after a matching/unification step). *) + fs_tw_uni : tcwitness TcUni.Muid.t; fs_mod : EcPath.mpath Mid.t; fs_modex : mod_extra Mid.t; fs_loc : form Mid.t; @@ -56,12 +65,14 @@ let fv_Mid (type a) (* -------------------------------------------------------------------- *) let f_subst_init ?(freshen=false) - ?(tu=Muid.empty) + ?(tu=TyUni.Muid.empty) ?(tv=Mid.empty) + ?(tw=Mid.empty) + ?(tw_uni=TcUni.Muid.empty) ?(esloc=Mid.empty) () = let fv = Mid.empty in - let fv = Muid.fold (fun _ t s -> fv_union s (ty_fv t)) tu fv in + let fv = TyUni.Muid.fold (fun _ t s -> fv_union s (ty_fv t)) tu fv in let fv = fv_Mid ty_fv tv fv in let fv = fv_Mid e_fv esloc fv in @@ -69,6 +80,8 @@ let f_subst_init fs_freshen = freshen; fs_u = tu; fs_v = tv; + fs_tw = tw; + fs_tw_uni = tw_uni; fs_mod = Mid.empty; fs_modex = Mid.empty; fs_loc = Mid.empty; @@ -158,8 +171,10 @@ let f_rem_mod (s : f_subst) (x : ident) : f_subst = (* -------------------------------------------------------------------- *) let is_ty_subst_id (s : f_subst) : bool = Mid.is_empty s.fs_mod - && Muid.is_empty s.fs_u + && TyUni.Muid.is_empty s.fs_u && Mid.is_empty s.fs_v + && Mid.is_empty s.fs_tw + && TcUni.Muid.is_empty s.fs_tw_uni (* -------------------------------------------------------------------- *) let rec ty_subst (s : f_subst) (ty : ty) : ty = @@ -169,7 +184,7 @@ let rec ty_subst (s : f_subst) (ty : ty) : ty = |> Option.map (fun ex -> ex.mex_tglob) |> Option.value ~default:ty | Tunivar id -> - Muid.find_opt id s.fs_u + TyUni.Muid.find_opt id s.fs_u |> Option.map (ty_subst s) |> Option.value ~default:ty | Tvar id -> @@ -181,6 +196,31 @@ let rec ty_subst (s : f_subst) (ty : ty) : ty = let ty_subst (s : f_subst) : ty -> ty = if is_ty_subst_id s then identity else ty_subst s +(* -------------------------------------------------------------------- *) +let rec tcw_subst (s : f_subst) (tcw : tcwitness) : tcwitness = + match tcw with + | TCIAbstract { support = `Var x; offset; lift } when Mid.mem x s.fs_tw -> + let ws = Mid.find x s.fs_tw in + if offset < List.length ws then + bump_lift lift (tcw_subst s (List.nth ws offset)) + else + tcw + | TCIAbstract _ -> tcw + | TCIUni (uid, lift) when TcUni.Muid.mem uid s.fs_tw_uni -> + bump_lift lift (tcw_subst s (TcUni.Muid.find uid s.fs_tw_uni)) + | TCIUni _ -> tcw + | TCIConcrete c -> + let etyargs' = List.Smart.map (etyarg_subst_inner s) c.etyargs in + if etyargs' == c.etyargs then tcw + else TCIConcrete { c with etyargs = etyargs' } + +and etyarg_subst_inner (s : f_subst) ((ty, ws) as e : etyarg) : etyarg = + let ty' = ty_subst s ty in + let ws' = List.Smart.map (tcw_subst s) ws in + if ty == ty' && ws == ws' then e else (ty', ws') + +let etyarg_subst (s : f_subst) (e : etyarg) : etyarg = etyarg_subst_inner s e + (* -------------------------------------------------------------------- *) let is_e_subst_id (s : f_subst) = not s.fs_freshen @@ -256,9 +296,9 @@ let rec e_subst (s : f_subst) (e : expr) : expr = e_var pv' ty' | Eop (p, tys) -> - let tys' = List.Smart.map (ty_subst s) tys in + let tys' = List.Smart.map (etyarg_subst s) tys in let ty' = ty_subst s e.e_ty in - e_op p tys' ty' + e_op_tc p tys' ty' | Elet (lp, e1, e2) -> let e1' = e_subst s e1 in @@ -433,8 +473,9 @@ module Fsubst = struct | Fop (p, tys) -> let ty' = ty_subst s fp.f_ty in - let tys' = List.Smart.map (ty_subst s) tys in - f_op p tys' ty' + let tys' = List.Smart.map (etyarg_subst s) tys in + if ty' == fp.f_ty && tys' == tys then fp + else f_op_tc p tys' ty' | Fpvar (pv, m) -> let pv' = pv_subst s pv in @@ -681,57 +722,64 @@ module Fsubst = struct let init_subst_tvar ~(freshen : bool) (s : ty Mid.t) : f_subst = f_subst_init ~freshen ~tv:s () - let f_subst_tvar ~(freshen : bool) (s : ty Mid.t) : form -> form = - f_subst (init_subst_tvar ~freshen s) + let f_subst_tvar ~(freshen : bool) (s : etyarg Mid.t) : form -> form = + let tv = Mid.map fst s in + let tw = Mid.map snd s in + f_subst (f_subst_init ~freshen ~tv ~tw ()) end (* -------------------------------------------------------------------- *) module Tuni = struct - let subst (uidmap : ty Muid.t) : f_subst = - f_subst_init ~tu:uidmap () + let subst ?(tw_uni = TcUni.Muid.empty) (uidmap : ty TyUni.Muid.t) : f_subst = + f_subst_init ~tu:uidmap ~tw_uni () - let subst1 ((id, t) : uid * ty) : f_subst = - subst (Muid.singleton id t) + let subst1 ((id, t) : tyuni * ty) : f_subst = + subst (TyUni.Muid.singleton id t) - let subst_dom (uidmap : ty Muid.t) (dom : dom) : dom = + let subst_dom (uidmap : ty TyUni.Muid.t) (dom : dom) : dom = List.map (ty_subst (subst uidmap)) dom - let occurs (u : uid) : ty -> bool = + let occurs (u : tyuni) : ty -> bool = let rec aux t = match t.ty_node with - | Tunivar u' -> uid_equal u u' + | Tunivar u' -> TyUni.uid_equal u u' | _ -> ty_sub_exists aux t in aux - let univars : ty -> Suid.t = + let univars : ty -> TyUni.Suid.t = let rec doit univars t = match t.ty_node with - | Tunivar uid -> Suid.add uid univars + | Tunivar uid -> TyUni.Suid.add uid univars | _ -> ty_fold doit univars t - in fun t -> doit Suid.empty t + in fun t -> doit TyUni.Suid.empty t - let rec fv_rec (fv : Suid.t) (t : ty) : Suid.t = + let rec fv_rec (fv : TyUni.Suid.t) (t : ty) : TyUni.Suid.t = match t.ty_node with - | Tunivar id -> Suid.add id fv + | Tunivar id -> TyUni.Suid.add id fv | _ -> ty_fold fv_rec fv t - let fv (ty : ty) : Suid.t = - fv_rec Suid.empty ty + let fv (ty : ty) : TyUni.Suid.t = + fv_rec TyUni.Suid.empty ty end (* -------------------------------------------------------------------- *) module Tvar = struct - let subst (s : ty Mid.t) (ty : ty) : ty = - ty_subst { f_subst_id with fs_v = s } ty + let subst (s : etyarg Mid.t) (ty : ty) : ty = + ty_subst { f_subst_id with fs_v = Mid.map fst s } ty - let subst1 ((id, t) : ebinding) (ty : ty) : ty = + let subst1 ((id, t) : EcIdent.t * etyarg) (ty : ty) : ty = subst (Mid.singleton id t) ty - let init (lv : ident list) (lt : ty list) : ty Mid.t = - assert (List.length lv = List.length lt); - List.fold_left2 (fun s v t -> Mid.add v t s) Mid.empty lv lt + let init (l : (EcIdent.t * etyarg) list) : etyarg Mid.t = + List.fold_left (fun s (v, t) -> Mid.add v t s) Mid.empty l + + let subst_etyarg (s : etyarg Mid.t) ((ty, w) : etyarg) : etyarg = + (subst s ty, w) + + let subst_tc (s : etyarg Mid.t) (tc : typeclass) : typeclass = + { tc with tc_args = List.map (subst_etyarg s) tc.tc_args } - let f_subst ~(freshen : bool) (lv : ident list) (lt : ty list) : form -> form = - Fsubst.f_subst_tvar ~freshen (init lv lt) + let f_subst ~(freshen : bool) (l : (EcIdent.t * etyarg) list) : form -> form = + Fsubst.f_subst_tvar ~freshen (init l) end diff --git a/src/ecCoreSubst.mli b/src/ecCoreSubst.mli index f829b8d387..d8826d82b0 100644 --- a/src/ecCoreSubst.mli +++ b/src/ecCoreSubst.mli @@ -1,5 +1,4 @@ (* -------------------------------------------------------------------- *) -open EcUid open EcIdent open EcPath open EcAst @@ -26,28 +25,33 @@ type 'a subst_binder = f_subst -> 'a -> f_subst * 'a (* -------------------------------------------------------------------- *) val f_subst_init : ?freshen:bool - -> ?tu:ty Muid.t + -> ?tu:ty TyUni.Muid.t -> ?tv:ty Mid.t + -> ?tw:tcwitness list Mid.t + -> ?tw_uni:tcwitness TcUni.Muid.t -> ?esloc:expr Mid.t -> unit -> f_subst (* -------------------------------------------------------------------- *) module Tuni : sig - val univars : ty -> Suid.t - val subst1 : (uid * ty) -> f_subst - val subst : ty Muid.t -> f_subst - val subst_dom : ty Muid.t -> dom -> dom - val occurs : uid -> ty -> bool - val fv : ty -> Suid.t + val univars : ty -> TyUni.Suid.t + val subst1 : (tyuni * ty) -> f_subst + val subst : ?tw_uni:tcwitness TcUni.Muid.t -> ty TyUni.Muid.t -> f_subst + val subst_dom : ty TyUni.Muid.t -> dom -> dom + val occurs : tyuni -> ty -> bool + val fv : ty -> TyUni.Suid.t end (* -------------------------------------------------------------------- *) module Tvar : sig - val init : EcIdent.t list -> ty list -> ty Mid.t - val subst1 : (EcIdent.t * ty) -> ty -> ty - val subst : ty Mid.t -> ty -> ty - val f_subst : freshen:bool -> EcIdent.t list -> ty list -> form -> form + val init : (EcIdent.t * etyarg) list -> etyarg Mid.t + val subst1 : (EcIdent.t * etyarg) -> ty -> ty + val subst : etyarg Mid.t -> ty -> ty + val subst_etyarg : etyarg Mid.t -> etyarg -> etyarg + val subst_tc : etyarg Mid.t -> typeclass -> typeclass + + val f_subst : freshen:bool -> (EcIdent.t * etyarg) list -> form -> form end (* -------------------------------------------------------------------- *) @@ -55,11 +59,11 @@ val add_elocal : (EcIdent.t * ty) subst_binder val add_elocals : (EcIdent.t * ty) list subst_binder val bind_elocal : f_subst -> EcIdent.t -> expr -> f_subst - (* -------------------------------------------------------------------- *) -val ty_subst : ty substitute -val e_subst : expr substitute -val s_subst : stmt substitute +val ty_subst : ty substitute +val etyarg_subst : etyarg substitute +val e_subst : expr substitute +val s_subst : stmt substitute (* -------------------------------------------------------------------- *) module Fsubst : sig @@ -68,8 +72,10 @@ module Fsubst : sig val f_subst_init : ?freshen:bool - -> ?tu:ty Muid.t + -> ?tu:ty TyUni.Muid.t -> ?tv:ty Mid.t + -> ?tw:tcwitness list Mid.t + -> ?tw_uni:tcwitness TcUni.Muid.t -> ?esloc:expr Mid.t -> unit -> f_subst @@ -85,11 +91,7 @@ module Fsubst : sig val f_subst_local : EcIdent.t -> form -> form -> form val f_subst_mem : EcIdent.t -> EcIdent.t -> form -> form - - val f_subst_tvar : - freshen:bool -> - EcTypes.ty EcIdent.Mid.t -> - form -> form + val f_subst_tvar : freshen:bool -> etyarg Mid.t -> form -> form val add_binding : binding subst_binder val add_bindings : bindings subst_binder diff --git a/src/ecDecl.ml b/src/ecDecl.ml index 2dbd8b27c6..f713ae87af 100644 --- a/src/ecDecl.ml +++ b/src/ecDecl.ml @@ -5,13 +5,12 @@ open EcTypes open EcCoreFol module Sp = EcPath.Sp -module TC = EcTypeClass module BI = EcBigInt module Ssym = EcSymbols.Ssym module CS = EcCoreSubst (* -------------------------------------------------------------------- *) -type ty_param = EcIdent.t +type ty_param = EcIdent.t * typeclass list type ty_params = ty_param list type ty_pctor = [ `Int of int | `Named of ty_params ] @@ -27,33 +26,36 @@ type ty_dtype = { tydt_schcase : EcCoreFol.form; } -type ty_body = - | Concrete of EcTypes.ty - | Abstract - | Datatype of ty_dtype - | Record of ty_record +type ty_body = [ + | `Concrete of EcTypes.ty + | `Abstract of typeclass list + | `Datatype of ty_dtype + | `Record of ty_record +] type tydecl = { - tyd_params : ty_params; - tyd_type : ty_body; - tyd_loca : locality; + tyd_params : ty_params; + tyd_type : ty_body; + tyd_resolve : bool; + tyd_loca : locality; + tyd_subtype : (EcTypes.ty * EcCoreFol.form) option; } let tydecl_as_concrete (td : tydecl) = - match td.tyd_type with Concrete x -> Some x | _ -> None + match td.tyd_type with `Concrete x -> Some x | _ -> None let tydecl_as_abstract (td : tydecl) = - match td.tyd_type with Abstract -> Some () | _ -> None + match td.tyd_type with `Abstract x -> Some x | _ -> None let tydecl_as_datatype (td : tydecl) = - match td.tyd_type with Datatype x -> Some x | _ -> None + match td.tyd_type with `Datatype x -> Some x | _ -> None let tydecl_as_record (td : tydecl) = - match td.tyd_type with Record (x, y) -> Some (x, y) | _ -> None + match td.tyd_type with `Record x -> Some x | _ -> None (* -------------------------------------------------------------------- *) -let abs_tydecl ?(params = `Int 0) lc = +let abs_tydecl ?(resolve = true) ?(tc = []) ?(params = `Int 0) lc = let params = match params with | `Named params -> @@ -61,15 +63,27 @@ let abs_tydecl ?(params = `Int 0) lc = | `Int n -> let fmt = fun x -> Printf.sprintf "'%s" x in List.map - (fun x -> (EcIdent.create x)) + (fun x -> (EcIdent.create x, [])) (EcUid.NameGen.bulk ~fmt n) in - { tyd_params = params; tyd_type = Abstract; tyd_loca = lc; } + { tyd_params = params; + tyd_type = `Abstract tc; + tyd_resolve = resolve; + tyd_loca = lc; + tyd_subtype = None; } + +(* -------------------------------------------------------------------- *) +let etyargs_of_tparams (tps : ty_params) : etyarg list = + List.map (fun (a, tcs) -> + let ety = + List.mapi (fun offset _ -> TCIAbstract { support = `Var a; offset; lift = [] }) tcs + in (tvar a, ety) + ) tps (* -------------------------------------------------------------------- *) -let ty_instantiate (params : ty_params) (args : ty list) (ty : ty) = - let subst = CS.Tvar.init params args in +let ty_instanciate (params : ty_params) (args : etyarg list) (ty : ty) = + let subst = CS.Tvar.init (List.combine (List.map fst params) args) in CS.Tvar.subst subst ty (* -------------------------------------------------------------------- *) @@ -87,7 +101,7 @@ and opbody = | OP_Proj of EcPath.path * int * int | OP_Fix of opfix | OP_Exn of ty list - | OP_TC + | OP_TC of EcPath.path * string and prbody = | PR_Plain of form @@ -187,6 +201,11 @@ let is_rcrd op = | OB_oper (Some (OP_Record _)) -> true | _ -> false +let is_tc_op op = + match op.op_kind with + | OB_oper (Some (OP_TC _)) -> true + | _ -> false + let is_fix op = match op.op_kind with | OB_oper (Some (OP_Fix _)) -> true @@ -268,6 +287,11 @@ let operator_as_prind (op : operator) = | OB_pred (Some (PR_Ind pri)) -> pri | _ -> assert false +let operator_as_tc (op : operator) = + match op.op_kind with + | OB_oper (Some (OP_TC (tcpath, name))) -> (tcpath, name) + | _ -> assert false + let operator_as_exception (op : operator) = match op.op_kind with | OB_oper (Some (OP_Exn exn_dom)) -> @@ -279,47 +303,17 @@ let operator_of_exception (ex: exception_) = mk_op ~opaque: optransparent [] ty (Some (OP_Exn ex.exn_dom)) ex.exn_loca (* -------------------------------------------------------------------- *) -let axiomatized_op - ?(nargs = 0) - ?(nosmt = false) - (path : EcPath.path) - ((tparams, axbd) : ty_params * form) - (lc : locality) - : axiom -= - let axbd, axpm = - let bdpm = tparams in - let axpm = List.map EcIdent.fresh bdpm in - (CS.Tvar.f_subst ~freshen:true bdpm (List.map EcTypes.tvar axpm) axbd, - axpm) - in - - let args, axbd = - match axbd.f_node with - | Fquant (Llambda, bds, axbd) -> - let bds, flam = List.split_at nargs bds in - (bds, f_lambda flam axbd) - | _ -> [], axbd - in - - let opargs = List.map (fun (x, ty) -> f_local x (gty_as_ty ty)) args in - let tyargs = List.map EcTypes.tvar axpm in - let op = f_op path tyargs (toarrow (List.map f_ty opargs) axbd.EcAst.f_ty) in - let op = f_app op opargs axbd.f_ty in - let axspec = f_forall args (f_eq op axbd) in - - { ax_tparams = axpm; - ax_spec = axspec; - ax_kind = `Axiom (Ssym.empty, false); - ax_loca = lc; - ax_smt = not nosmt; } - -(* -------------------------------------------------------------------- *) -type typeclass = { - tc_prt : EcPath.path option; - tc_ops : (EcIdent.t * EcTypes.ty) list; - tc_axs : (EcSymbols.symbol * EcCoreFol.form) list; - tc_loca: is_local; +(* A parent typeclass plus an optional op renaming. The renaming maps + the parent's op names (recursively, including its own ancestors) + to op names declared in or inherited by the subclass — used to + project a subclass instance into a parent instance with different + operator names. Empty list = plain inheritance. *) +type tc_decl = { + tc_tparams : ty_params; + tc_prts : (typeclass * (EcSymbols.symbol * EcSymbols.symbol) list) list; + tc_ops : (EcIdent.t * EcTypes.ty) list; + tc_axs : (EcSymbols.symbol * EcCoreFol.form) list; + tc_loca : is_local; } (* -------------------------------------------------------------------- *) diff --git a/src/ecDecl.mli b/src/ecDecl.mli index db24d1950b..41b376acba 100644 --- a/src/ecDecl.mli +++ b/src/ecDecl.mli @@ -1,12 +1,13 @@ (* -------------------------------------------------------------------- *) open EcUtils +open EcAst open EcSymbols open EcBigInt open EcTypes open EcCoreFol (* -------------------------------------------------------------------- *) -type ty_param = EcIdent.t +type ty_param = EcIdent.t * typeclass list type ty_params = ty_param list type ty_pctor = [ `Int of int | `Named of ty_params ] @@ -22,27 +23,40 @@ type ty_dtype = { tydt_schcase : EcCoreFol.form; } -type ty_body = - | Concrete of EcTypes.ty - | Abstract - | Datatype of ty_dtype - | Record of ty_record +and ty_body = [ + | `Concrete of EcTypes.ty + | `Abstract of typeclass list + | `Datatype of ty_dtype + | `Record of ty_record +] type tydecl = { - tyd_params : ty_params; - tyd_type : ty_body; - tyd_loca : locality; + tyd_params : ty_params; + tyd_type : ty_body; + tyd_resolve : bool; + tyd_loca : locality; + (* For [subtype]-declared types: the carrier and the predicate. The + declared type itself stays [tyd_type = `Abstract []], because a + subtype is semantically a fresh abstract type — but its dependency + on free type variables (when declared inside a section) must be + visible to the section-close machinery. [tydecl_fv] unions the + carrier+predicate fv into the type's fv when this field is set, + so a subtype declared inside [section. declare type c <: tc.] gets + the section's tparams added at close, just like type aliases do. *) + tyd_subtype : (EcTypes.ty * EcCoreFol.form) option; } val tydecl_as_concrete : tydecl -> EcTypes.ty option -val tydecl_as_abstract : tydecl -> unit option +val tydecl_as_abstract : tydecl -> typeclass list option val tydecl_as_datatype : tydecl -> ty_dtype option -val tydecl_as_record : tydecl -> (form * (EcSymbols.symbol * EcTypes.ty) list) option +val tydecl_as_record : tydecl -> ty_record option + +val abs_tydecl : ?resolve:bool -> ?tc:typeclass list -> ?params:ty_pctor -> locality -> tydecl -val abs_tydecl : ?params:ty_pctor -> locality -> tydecl +val etyargs_of_tparams : ty_params -> etyarg list -val ty_instantiate : ty_params -> ty list -> ty -> ty +val ty_instanciate : ty_params -> etyarg list -> ty -> ty (* -------------------------------------------------------------------- *) type exception_ = { @@ -67,7 +81,7 @@ and opbody = | OP_Proj of EcPath.path * int * int | OP_Fix of opfix | OP_Exn of ty list - | OP_TC + | OP_TC of EcPath.path * string and prbody = | PR_Plain of form @@ -126,6 +140,7 @@ val is_oper : operator -> bool val is_ctor : operator -> bool val is_proj : operator -> bool val is_rcrd : operator -> bool +val is_tc_op : operator -> bool val is_fix : operator -> bool val is_abbrev : operator -> bool val is_prind : operator -> bool @@ -145,6 +160,7 @@ val operator_as_rcrd : operator -> EcPath.path val operator_as_proj : operator -> EcPath.path * int * int val operator_as_fix : operator -> opfix val operator_as_prind : operator -> prind +val operator_as_tc : operator -> EcPath.path * string val operator_as_exception : operator -> exception_ val operator_of_exception : exception_ -> operator @@ -165,20 +181,12 @@ val is_axiom : axiom_kind -> bool val is_lemma : axiom_kind -> bool (* -------------------------------------------------------------------- *) -val axiomatized_op : - ?nargs: int - -> ?nosmt:bool - -> EcPath.path - -> (ty_params * form) - -> locality - -> axiom - -(* -------------------------------------------------------------------- *) -type typeclass = { - tc_prt : EcPath.path option; - tc_ops : (EcIdent.t * EcTypes.ty) list; - tc_axs : (EcSymbols.symbol * form) list; - tc_loca: is_local; +type tc_decl = { + tc_tparams : ty_params; + tc_prts : (typeclass * (EcSymbols.symbol * EcSymbols.symbol) list) list; + tc_ops : (EcIdent.t * EcTypes.ty) list; + tc_axs : (EcSymbols.symbol * EcCoreFol.form) list; + tc_loca : is_local; } (* -------------------------------------------------------------------- *) diff --git a/src/ecEnv.ml b/src/ecEnv.ml index 8a449a7797..8d8f91d05e 100644 --- a/src/ecEnv.ml +++ b/src/ecEnv.ml @@ -18,8 +18,8 @@ module Msym = EcSymbols.Msym module Mp = EcPath.Mp module Sid = EcIdent.Sid module Mid = EcIdent.Mid -module TC = EcTypeClass module Mint = EcMaps.Mint +module Mstr = EcMaps.Mstr (* -------------------------------------------------------------------- *) type 'a suspension = { @@ -89,7 +89,8 @@ type mc = { mc_operators : (ipath * EcDecl.operator) MMsym.t; mc_axioms : (ipath * EcDecl.axiom) MMsym.t; mc_theories : (ipath * ctheory) MMsym.t; - mc_typeclasses: (ipath * typeclass) MMsym.t; + mc_typeclasses: (ipath * tc_decl) MMsym.t; + mc_tcinstances: (ipath * tcinstance) MMsym.t; mc_rwbase : (ipath * path) MMsym.t; mc_components : ipath MMsym.t; } @@ -183,8 +184,7 @@ type preenv = { env_memories : EcMemory.memtype Mmem.t; env_actmem : actmem option; env_abs_st : EcModules.abs_uses Mid.t; - env_tci : ((ty_params * ty) * tcinstance) list; - env_tc : TC.graph; + env_tci : (path option * tcinstance) list; env_rwbase : Sp.t Mip.t; env_atbase : atbase Msym.t; env_redbase : mredinfo; @@ -211,12 +211,6 @@ and scope = [ | `Fun of EcPath.xpath ] -and tcinstance = [ - | `Ring of EcDecl.ring - | `Field of EcDecl.field - | `General of EcPath.path -] - and redinfo = { ri_priomap : (EcTheory.rule list) Mint.t; ri_list : (EcTheory.rule list) Lazy.t; } @@ -282,6 +276,7 @@ let empty_mc params = { mc_variables = MMsym.empty; mc_functions = MMsym.empty; mc_typeclasses= MMsym.empty; + mc_tcinstances= MMsym.empty; mc_rwbase = MMsym.empty; mc_components = MMsym.empty; } @@ -313,7 +308,6 @@ let empty gstate = env_actmem = None; env_abs_st = Mid.empty; env_tci = []; - env_tc = TC.Graph.empty; env_rwbase = Mip.empty; env_atbase = Msym.empty; env_redbase = Mrd.empty; @@ -512,12 +506,13 @@ module MC = struct | IPIdent _ -> assert false | IPPath p -> p - let _downpath_for_tydecl = _downpath_for_th - let _downpath_for_modsig = _downpath_for_th - let _downpath_for_operator = _downpath_for_th - let _downpath_for_axiom = _downpath_for_th - let _downpath_for_typeclass = _downpath_for_th - let _downpath_for_rwbase = _downpath_for_th + let _downpath_for_tydecl = _downpath_for_th + let _downpath_for_modsig = _downpath_for_th + let _downpath_for_operator = _downpath_for_th + let _downpath_for_axiom = _downpath_for_th + let _downpath_for_typeclass = _downpath_for_th + let _downpath_for_tcinstance = _downpath_for_th + let _downpath_for_rwbase = _downpath_for_th (* ------------------------------------------------------------------ *) let _params_of_path p env = @@ -789,16 +784,16 @@ module MC = struct let loca = tyd.tyd_loca in match tyd.tyd_type with - | Concrete _ -> mc - | Abstract -> mc + | `Concrete _ -> mc + | `Abstract _ -> mc - | Datatype dtype -> + | `Datatype dtype -> let cs = dtype.tydt_ctors in let schelim = dtype.tydt_schelim in let schcase = dtype.tydt_schcase in - let params = List.map tvar tyd.tyd_params in + let params = etyargs_of_tparams tyd.tyd_params in let for1 i (c, aty) = - let aty = EcTypes.toarrow aty (tconstr mypath params) in + let aty = toarrow aty (tconstr_tc mypath params) in let aty = EcSubst.freshen_type (tyd.tyd_params, aty) in let cop = mk_op ~opaque:optransparent (fst aty) (snd aty) @@ -836,12 +831,12 @@ module MC = struct _up_operator candup mc name (ipath name, op) ) mc projs - | Record (scheme, fields) -> - let params = List.map tvar tyd.tyd_params in + | `Record (scheme, fields) -> + let params = etyargs_of_tparams tyd.tyd_params in let nfields = List.length fields in let cfields = let for1 i (f, aty) = - let aty = EcTypes.tfun (tconstr mypath params) aty in + let aty = tfun (tconstr_tc mypath params) aty in let aty = EcSubst.freshen_type (tyd.tyd_params, aty) in let fop = mk_op ~opaque:optransparent (fst aty) (snd aty) (Some (OP_Proj (mypath, i, nfields))) loca in @@ -862,7 +857,7 @@ module MC = struct let stname = Printf.sprintf "mk_%s" x in let stop = - let stty = toarrow (List.map snd fields) (tconstr mypath params) in + let stty = toarrow (List.map snd fields) (tconstr_tc mypath params) in let stty = EcSubst.freshen_type (tyd.tyd_params, stty) in mk_op ~opaque:optransparent (fst stty) (snd stty) (Some (OP_Record mypath)) loca in @@ -912,24 +907,31 @@ module MC = struct let self = EcIdent.create "'self" in - let tsubst =EcSubst.add_tydef EcSubst.empty mypath ([], tvar self) in + let tsubst =EcSubst.add_tydef EcSubst.empty mypath ([], tvar self, []) in let operators = let on1 (opid, optype) = let opname = EcIdent.name opid in let optype = EcSubst.subst_ty tsubst optype in - let opdecl = - mk_op ~opaque:optransparent [(self)] - optype (Some OP_TC) loca - in (opid, xpath opname, optype, opdecl) + let tcargs = etyargs_of_tparams tc.tc_tparams in + let opargs = (self, [{tc_name = mypath; tc_args = tcargs;}]) in + let opargs = tc.tc_tparams @ [opargs] in + let opdecl = OP_TC (mypath, opname) in + let opdecl = mk_op ~opaque:optransparent opargs optype (Some opdecl) loca in + (opid, xpath opname, optype, opdecl) in List.map on1 tc.tc_ops in let fsubst = + let op_etyargs = + let tparams = + tc.tc_tparams + @ [(self, [{tc_name = mypath; tc_args = etyargs_of_tparams tc.tc_tparams}])] + in EcDecl.etyargs_of_tparams tparams in List.fold_left (fun s (x, xp, xty, _) -> - let fop = EcCoreFol.f_op xp [tvar self] xty in + let fop = EcCoreFol.f_op_tc xp op_etyargs xty in EcSubst.add_flocal s x fop) tsubst operators @@ -938,8 +940,11 @@ module MC = struct let axioms = List.map (fun (x, ax) -> + let tcargs = etyargs_of_tparams tc.tc_tparams in + let axargs = (self, [{tc_name = mypath; tc_args = tcargs}]) in + let axargs = tc.tc_tparams @ [axargs] in let ax = EcSubst.subst_form fsubst ax in - (x, { ax_tparams = [(self)]; + (x, { ax_tparams = axargs; ax_spec = ax; ax_kind = `Lemma; ax_loca = loca; @@ -963,6 +968,20 @@ module MC = struct let import_typeclass p ax env = import (_up_typeclass true) (IPPath p) ax env + (* -------------------------------------------------------------------- *) + let lookup_tcinstance qnx env = + match lookup (fun mc -> mc.mc_tcinstances) qnx env with + | None -> lookup_error (`QSymbol qnx) + | Some (p, (args, obj)) -> (_downpath_for_tcinstance env p args, obj) + + let _up_tcinstance candup mc x obj= + if not candup && MMsym.last x mc.mc_tcinstances <> None then + raise (DuplicatedBinding x); + { mc with mc_tcinstances = MMsym.add x obj mc.mc_tcinstances } + + let import_tcinstance p tci env = + import (_up_tcinstance true) (IPPath p) tci env + (* -------------------------------------------------------------------- *) let lookup_rwbase qnx env = match lookup (fun mc -> mc.mc_rwbase) qnx env with @@ -1115,6 +1134,16 @@ module MC = struct else (add2mc _up_theory xsubth cth mc, None) + | Th_typeclass (x, tc) -> + (add2mc _up_typeclass x tc mc, None) + + | Th_instance (x, tci) -> + let mc = + x |> Option.fold + ~none:mc + ~some:(fun x -> add2mc _up_tcinstance x tci mc) + in (mc, None) + | Th_baserw (x, _) -> (add2mc _up_rwbase x (expath x) mc, None) @@ -1122,8 +1151,7 @@ module MC = struct (* FIXME:ALIAS *) (mc, None) - | Th_export _ | Th_addrw _ | Th_instance _ - | Th_auto _ | Th_reduction _ -> + | Th_export _ | Th_addrw _ | Th_auto _ | Th_reduction _ -> (mc, None) in @@ -1202,6 +1230,9 @@ module MC = struct and bind_typeclass x tc env = bind _up_typeclass x tc env + and bind_tcinstance x tci env = + bind _up_tcinstance x tci env + and bind_rwbase x p env = bind _up_rwbase x p env end @@ -1387,7 +1418,7 @@ let gen_all fmc flk ?(check = fun _ _ -> true) ?name (env : env) = (* ------------------------------------------------------------------ *) module TypeClass = struct - type t = typeclass + type t = tc_decl let by_path_opt (p : EcPath.path) (env : env) = omap @@ -1400,39 +1431,77 @@ module TypeClass = struct | Some obj -> obj let add (p : EcPath.path) (env : env) = - let obj = by_path p env in - MC.import_typeclass p obj env + MC.import_typeclass p (by_path p env) env - let rebind name tc env = - let env = MC.bind_typeclass name tc env in - match tc.tc_prt with - | None -> env - | Some prt -> - let myself = EcPath.pqname (root env) name in - { env with env_tc = TC.Graph.add ~src:myself ~dst:prt env.env_tc } + let rebind (name : symbol) (tc : t) (env : env) = + MC.bind_typeclass name tc env - let lookup qname (env : env) = + let bind ?(import = true) (name : symbol) (tc : t) (env : env) = + let env = if import then rebind name tc env else env in + { env with + env_item = mkitem ~import (Th_typeclass (name, tc)) :: env.env_item } + + let lookup (qname : qsymbol) (env : env) = MC.lookup_typeclass qname env - let lookup_opt name env = + let lookup_opt (name : qsymbol) (env : env) = try_lf (fun () -> lookup name env) - let lookup_path name env = + let lookup_path (name : qsymbol) (env : env) = fst (lookup name env) +end - let graph (env : env) = - env.env_tc +(* ------------------------------------------------------------------ *) +module TcInstance = struct + type t = tcinstance + + let by_path_opt (p : EcPath.path) (env : env) = + omap + check_not_suspended + (MC.by_path (fun mc -> mc.mc_tcinstances) (IPPath p) env) - let bind_instance ty cr tci = - (ty, cr) :: tci + let by_path (p : EcPath.path) (env : env) = + match by_path_opt p env with + | None -> lookup_error (`Path p) + | Some obj -> obj - let add_instance ?(import = true) ty cr lc env = - let item = Th_instance (ty, cr, lc) in + let add (p : EcPath.path) (env : env) = + MC.import_tcinstance p (by_path p env) env + + let bind_instance (path : path option) (tci : t) (env : _) = + (path, tci) :: env + + let rebind (name : symbol option) (tci : t) (env : env) = + let env = + name |> Option.fold ~none:env ~some:(fun name -> + MC.bind_tcinstance name tci env) + in + let path = + Option.map + (fun name -> EcPath.pqname (root env) name) + name + in { env with env_tci = bind_instance path tci env.env_tci } + + let bind ?(import = true) (name : symbol option) (tci : t) (env : env) = + let env = + if import then rebind name tci env else env in { env with - env_tci = bind_instance ty cr env.env_tci; - env_item = mkitem ~import item :: env.env_item; } + env_item = mkitem ~import (Th_instance (name, tci)) :: env.env_item } + + let lookup qname (env : env) = + MC.lookup_tcinstance qname env - let get_instances env = env.env_tci + let lookup_opt (name : qsymbol) (env : env) = + try_lf (fun () -> lookup name env) + + let lookup_path (name : qsymbol) (env : env) = + fst (lookup name env) + + let get_instances (env : env) = + env.env_tci + + let get_all (env : env) : (path option * t) list = + env.env_tci end (* -------------------------------------------------------------------- *) @@ -2528,7 +2597,7 @@ module Ty = struct let add (p : EcPath.path) (env : env) = let obj = by_path p env in - MC.import_tydecl p obj env + MC.import_tydecl p obj env let lookup ?unique (qname : qsymbol) (env : env) = MC.lookup_tydecl ?unique qname env @@ -2541,14 +2610,14 @@ module Ty = struct let defined (name : EcPath.path) (env : env) = match by_path_opt name env with - | Some { tyd_type = Concrete _ } -> true + | Some { tyd_type = `Concrete _ } -> true | _ -> false - let unfold (name : EcPath.path) (args : EcTypes.ty list) (env : env) = + let unfold (name : EcPath.path) (args : etyarg list) (env : env) = match by_path_opt name env with - | Some ({ tyd_type = Concrete body } as tyd) -> + | Some ({ tyd_type = `Concrete body } as tyd) -> Tvar.subst - (Tvar.init tyd.tyd_params args) + (Tvar.init (List.combine (List.map fst tyd.tyd_params) args)) body | _ -> raise (LookupFailure (`Path name)) @@ -2557,13 +2626,11 @@ module Ty = struct | Tconstr (p, tys) when defined p env -> hnorm (unfold p tys env) env | _ -> ty - let rec ty_hnorm (ty : ty) (env : env) = match ty.ty_node with | Tconstr (p, tys) when defined p env -> ty_hnorm (unfold p tys env) env | _ -> ty - let rec decompose_fun (ty : ty) (env : env) : dom * ty = match (hnorm ty env).ty_node with | Tfun (ty1, ty2) -> @@ -2582,14 +2649,14 @@ module Ty = struct match ty.ty_node with | Tconstr (p, tys) -> begin match by_path_opt p env with - | Some ({ tyd_type = (Datatype _ | Record _) as body }) -> + | Some ({ tyd_type = (`Datatype _ | `Record _) as body }) -> let prefix = EcPath.prefix p in let basename = EcPath.basename p in let basename = match body, mode with - | Record _, (`Ind | `Case) -> basename ^ "_ind" - | Datatype _, `Ind -> basename ^ "_ind" - | Datatype _, `Case -> basename ^ "_case" + | `Record _, (`Ind | `Case) -> basename ^ "_ind" + | `Datatype _, `Ind -> basename ^ "_ind" + | `Datatype _, `Case -> basename ^ "_case" | _, _ -> assert false in Some (EcPath.pqoname prefix basename, tys) @@ -2602,11 +2669,11 @@ module Ty = struct | Tconstr (p, tys) -> Some (p, oget (by_path_opt p env), tys) | _ -> None - let rebind name ty env = - MC.bind_tydecl name ty env + let rebind (name : symbol) (tyd : t) (env : env) = + MC.bind_tydecl name tyd env let bind ?(import = true) name ty env = - let env = rebind name ty env in + let env = if import then rebind name ty env else env in { env with env_item = mkitem ~import (Th_type (name, ty)) :: env.env_item } @@ -2678,7 +2745,6 @@ module Op = struct let core_reduce ?(mode = `IfTransparent) ?(nargs = 0) env p = let op = oget (by_path_opt p env) in - match op.op_kind with | OB_oper (Some (OP_Plain f)) | OB_pred (Some (PR_Plain f)) -> begin @@ -2706,8 +2772,215 @@ module Op = struct else false let reduce ?mode ?nargs env p tys = - let op, f = core_reduce ?mode ?nargs env p in - Tvar.f_subst ~freshen:true op.op_tparams tys f + let op, form = core_reduce ?mode ?nargs env p in + Tvar.f_subst ~freshen:true + (List.combine (List.map fst op.op_tparams) tys) + form + + let tc_core_reduce (env : env) (p : path) (tys : etyarg list) = + let op = by_path p env in + + if not (is_tc_op op) then + raise NotReducible; + + (* Last type application if the TC parameter. We extract the type-class * + * information from the witness. *) + let _, (_, tcw) = List.betail tys in + + match as_seq1 tcw with + | TCIConcrete { path = tcipath; etyargs = tciargs; lift } -> begin + let tci = TcInstance.by_path tcipath env in + + (* The witness's [lift] is a path through the parent DAG: each + element selects which parent edge to take. We follow it via + [tci_parents] (the synthesised parent instance paths). For + single-parent classes the path is always all-zeros; for + multi-parent (factory) classes the path encodes which + parent is taken at each step. + + Fallback when [tci_parents] is empty (manually-declared + instance with no synthesis tracking): walk the TC parent + chain naively and search the database for a matching + ancestor instance. This loses path-disambiguation but + covers the legacy single-parent case. *) + let resolve_lifted () = + if lift = [] then None + else + let rec walk tci = function + | [] -> Some tci + | i :: rest -> + match List.nth_opt tci.tci_parents i with + | None -> None + | Some parent_path -> + let parent_tci = TcInstance.by_path parent_path env in + walk parent_tci rest + in + match walk tci lift with + | Some target_tci -> begin + match target_tci.tci_instance with + | `General (_, Some sym) -> Some (target_tci, sym) + | _ -> None + end + | None -> + (* Fallback: walk the TC parent chain (taking parent #0 + at each step — equivalent to the all-zeros path) and + search the database for the matching ancestor instance + on the same carrier. *) + let walk_up_tc (tc : typeclass) (path : int list) : typeclass option = + let rec aux tc = function + | [] -> Some tc + | i :: rest -> + let decl = TypeClass.by_path tc.tc_name env in + match List.nth_opt decl.tc_prts i with + | None -> None + | Some (parent, _ren) -> + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> Mid.add a etyarg s) + Mid.empty decl.tc_tparams tc.tc_args in + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + aux parent rest + in aux tc path in + match tci.tci_instance with + | `General (tgp, _) -> begin + match walk_up_tc tgp lift with + | None -> None + | Some target -> + let carrier = tci.tci_type in + List.fold_left (fun acc (_, tci_existing) -> + match acc with + | Some _ -> acc + | None -> + match tci_existing.tci_instance with + | `General (tgp', Some sym) + when EcPath.p_equal tgp'.tc_name target.tc_name + && EcTypes.ty_equal tci_existing.tci_type carrier -> + Some (tci_existing, sym) + | _ -> None) + None (TcInstance.get_all env) + end + | _ -> None in + + match resolve_lifted () with + | Some (tci_target, symbols) -> + (EcDecl.operator_as_tc op, + (tciargs, (tci_target.tci_params, symbols))) + | None -> + match tci.tci_instance with + | `General (_, Some symbols) -> + (EcDecl.operator_as_tc op, (tciargs, (tci.tci_params, symbols))) + | _ -> raise NotReducible + end + + | _ -> + raise NotReducible + + (* Try to unfold a TC op via a factory rename when the witness is + [TCIAbstract]. Walks the parent-DAG path defined by [(offset, lift)] + looking for an edge whose rename maps [basename p] to a different + name. If found, returns the renamed op-path with a witness lifted + up to (but not including) that edge. The renamed op is a class op + declared in the child class of the renaming edge. *) + let tc_reduce_abstract_via_rename + (env : env) (p : path) (tys : etyarg list) + : form option + = + match by_path_opt p env with + | None -> None + | Some op when not (EcDecl.is_tc_op op) -> None + | Some _op -> + let prefix_tys, last_ety = List.betail tys in + let _, tcws = last_ety in + match tcws with + | [TCIAbstract { support; offset; lift }] -> begin + let opname = EcPath.basename p in + let tcs_opt = + match support with + | `Abs ap -> begin + match Ty.by_path_opt ap env with + | Some { tyd_type = `Abstract tcs; _ } -> Some tcs + | _ -> None + end + | `Var _ -> None in + match tcs_opt with + | None -> None + | Some tcs -> + if offset >= List.length tcs then None + else + let start = List.nth tcs offset in + let rec walk cur acc_lift_rev = function + | [] -> None + | i :: rest -> + let decl = TypeClass.by_path cur.tc_name env in + match List.nth_opt decl.tc_prts i with + | None -> None + | Some (parent, edge_ren) -> + let subst = + List.fold_left2 + (fun s (a, _) ety -> Mid.add a ety s) + Mid.empty decl.tc_tparams cur.tc_args in + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + match walk parent (i :: acc_lift_rev) rest with + | Some _ as r -> r + | None -> + match List.assoc_opt opname edge_ren with + | Some new_name when new_name <> opname -> + Some (cur, new_name, List.rev acc_lift_rev) + | _ -> None in + match walk start [] lift with + | None -> None + | Some (cur_class, new_name, new_lift) -> + (* Class ops live at the THEORY level (sibling of the class + declaration), not under the class's own path. Strip one + component and append the renamed op-name. *) + let theory_prefix = + match EcPath.prefix cur_class.tc_name with + | Some pr -> pr + | None -> cur_class.tc_name in + let new_path = EcPath.pqname theory_prefix new_name in + match by_path_opt new_path env with + | None -> None + | Some new_op -> + let new_witness = + TCIAbstract { support; offset; lift = new_lift } in + let new_etyargs = + prefix_tys @ [(fst last_ety, [new_witness])] in + let tysubst = + EcCoreSubst.Tvar.init + (List.combine + (List.map fst new_op.op_tparams) + new_etyargs) in + let new_ty = EcCoreSubst.Tvar.subst tysubst new_op.op_ty in + Some (f_op_tc new_path new_etyargs new_ty) + end + | _ -> None + + let tc_reducible (env : env) (p : path) (tys : etyarg list) = + try ignore (tc_core_reduce env p tys); true + with NotReducible -> + Option.is_some (tc_reduce_abstract_via_rename env p tys) + + let tc_reduce (env : env) (p : path) (tys : etyarg list) = + try + let ((_, opname), (tciargs, (tciparams, symbols))) = + tc_core_reduce env p tys in + let subst = + List.fold_left + (fun subst (a, ety) -> + let ety = EcSubst.subst_etyarg subst ety in + EcSubst.add_tyvar subst a ety) + EcSubst.empty + (List.combine (List.map fst tciparams) tciargs) + in + let optg, opargs = EcMaps.Mstr.find opname symbols in + let opargs = List.map (EcSubst.subst_etyarg subst) opargs in + let optg_decl = by_path optg env in + let tysubst = Tvar.init (List.combine (List.map fst optg_decl.op_tparams) opargs) in + f_op_tc optg opargs (Tvar.subst tysubst optg_decl.op_ty) + with NotReducible -> + match tc_reduce_abstract_via_rename env p tys with + | Some f -> f + | None -> raise NotReducible let is_projection env p = try EcDecl.is_proj (by_path p env) @@ -2717,6 +2990,23 @@ module Op = struct try EcDecl.is_rcrd (by_path p env) with LookupFailure _ -> false + let is_tc_op env p = + try EcDecl.is_tc_op (by_path p env) + with LookupFailure _ -> false + + let tc_op_realised_by (env : env) (tcop : path) (concrete : path) = + if not (is_tc_op env tcop) then false + else + let tcop_basename = EcPath.basename tcop in + List.exists (fun (_, tci) -> + match tci.EcTheory.tci_instance with + | `General (_, Some sym) -> + (match EcMaps.Mstr.find_opt tcop_basename sym with + | Some (p, _) -> EcPath.p_equal p concrete + | None -> false) + | _ -> false) + (TcInstance.get_all env) + let is_dtype_ctor ?nargs env p = try match (by_path p env).op_kind with @@ -2825,10 +3115,10 @@ module Ax = struct let rebind name ax env = MC.bind_axiom name ax env - let instantiate p tys env = + let instanciate p tys env = match by_path_opt p env with | Some ({ ax_spec = f } as ax) -> - Tvar.f_subst ~freshen:true ax.ax_tparams tys f + Tvar.f_subst ~freshen:true (List.combine (List.map fst ax.ax_tparams) tys) f | _ -> raise (LookupFailure (`Path p)) let iter ?name f (env : env) = @@ -2838,22 +3128,6 @@ module Ax = struct gen_all (fun mc -> mc.mc_axioms) MC.lookup_axioms ?check ?name env end -(* -------------------------------------------------------------------- *) -module Algebra = struct - let bind_ring ty cr env = - assert (Mid.is_empty ty.ty_fv); - { env with env_tci = - TypeClass.bind_instance ([], ty) (`Ring cr) env.env_tci } - - let bind_field ty cr env = - assert (Mid.is_empty ty.ty_fv); - { env with env_tci = - TypeClass.bind_instance ([], ty) (`Field cr) env.env_tci } - - let add_ring ty cr lc env = TypeClass.add_instance ([], ty) (`Ring cr) lc env - let add_field ty cr lc env = TypeClass.add_instance ([], ty) (`Field cr) lc env -end - (* -------------------------------------------------------------------- *) module Theory = struct type t = ctheory @@ -2937,8 +3211,8 @@ module Theory = struct let xpath x = EcPath.pqname path x in match item.ti_item with - | Th_instance (ty, k, _) -> - TypeClass.bind_instance ty k inst + | Th_instance (name, tci) -> + TcInstance.bind_instance (Option.map xpath name) tci inst | Th_theory (x, cth) when cth.cth_mode = `Concrete -> bind_instance_th (xpath x) inst cth.cth_items @@ -2963,6 +3237,16 @@ module Theory = struct end | _ -> odfl base (tx path base item.ti_item) + (* ------------------------------------------------------------------ *) + let bind_tc_th = + let for1 _path base = function + | Th_typeclass (_, tc) -> + Some (tc :: base) + + | _ -> None + + in bind_base_th for1 + (* ------------------------------------------------------------------ *) let bind_br_th = let for1 path base = function @@ -3040,9 +3324,7 @@ module Theory = struct let env_ntbase = bind_nt_th thname env.env_ntbase items in let env_redbase = bind_rd_th thname env.env_redbase items in let env = - { env with - env_tci ; env_rwbase; - env_atbase; env_ntbase; env_redbase; } + { env with env_tci; env_rwbase; env_atbase; env_ntbase; env_redbase; } in add_restr_th thname env items @@ -3099,6 +3381,9 @@ module Theory = struct | Th_alias (name, path) -> rebind_alias name path env + | Th_typeclass (x, tc) -> + MC.import_typeclass (xpath x) tc env + | Th_addrw _ | Th_instance _ | Th_auto _ | Th_reduction _ -> env diff --git a/src/ecEnv.mli b/src/ecEnv.mli index 04354a391d..dadb519d01 100644 --- a/src/ecEnv.mli +++ b/src/ecEnv.mli @@ -180,7 +180,7 @@ module Ax : sig val iter : ?name:qsymbol -> (path -> t -> unit) -> env -> unit val all : ?check:(path -> t -> bool) -> ?name:qsymbol -> env -> (path * t) list - val instantiate : path -> EcTypes.ty list -> env -> form + val instanciate : path -> etyarg list -> env -> form end (* -------------------------------------------------------------------- *) @@ -328,11 +328,26 @@ module Op : sig val bind : ?import:bool -> symbol -> operator -> env -> env val reducible : ?mode:redmode -> ?nargs:int -> env -> path -> bool - val reduce : ?mode:redmode -> ?nargs:int -> env -> path -> ty list -> form + val reduce : ?mode:redmode -> ?nargs:int -> env -> path -> etyarg list -> form + + val tc_reducible : env -> path -> etyarg list -> bool + val tc_reduce : env -> path -> etyarg list -> form + + (* [tc_op_realised_by env tcop concrete] is true iff [tcop] is a + TC-class op and there exists a registered instance whose + symbol-map binds [tcop]'s basename to [concrete]. Used by the + matcher to bridge a pattern with a TC-op head whose carrier + is still a univar to a goal whose head is the registered + realisation, so e.g. [rewrite mul0r] (no TVI) matches goals + containing the structural [polyM]. The lookup is purely + syntactic — the caller must still post the carrier-pinning + unification that makes the bridge sound. *) + val tc_op_realised_by : env -> path -> path -> bool val is_projection : env -> path -> bool val is_record_ctor : env -> path -> bool val is_dtype_ctor : ?nargs:int -> env -> path -> bool + val is_tc_op : env -> path -> bool val is_fix_def : env -> path -> bool val is_abbrev : env -> path -> bool val is_prind : env -> path -> bool @@ -362,16 +377,15 @@ module Ty : sig val bind : ?import:bool -> symbol -> t -> env -> env val defined : path -> env -> bool - val unfold : path -> EcTypes.ty list -> env -> EcTypes.ty - val hnorm : EcTypes.ty -> env -> EcTypes.ty - val decompose_fun : EcTypes.ty -> env -> EcTypes.dom * EcTypes.ty + val unfold : path -> etyarg list -> env -> ty + val hnorm : ty -> env -> ty + val decompose_fun : ty -> env -> EcTypes.dom * ty val get_top_decl : - EcTypes.ty -> env -> (path * EcDecl.tydecl * EcTypes.ty list) option - + EcTypes.ty -> env -> (path * EcDecl.tydecl * etyarg list) option val scheme_of_ty : - [`Ind | `Case] -> EcTypes.ty -> env -> (path * EcTypes.ty list) option + [`Ind | `Case] -> EcTypes.ty -> env -> (path * etyarg list) option val signature : env -> ty -> ty list * ty @@ -382,17 +396,26 @@ end val ty_hnorm : ty -> env -> ty (* -------------------------------------------------------------------- *) -module Algebra : sig - val add_ring : ty -> EcDecl.ring -> is_local -> env -> env - val add_field : ty -> EcDecl.field -> is_local -> env -> env +module TypeClass : sig + type t = tc_decl + + val add : path -> env -> env + val bind : ?import:bool -> symbol -> t -> env -> env + val rebind : symbol -> t -> env -> env + + val by_path : path -> env -> t + val by_path_opt : path -> env -> t option + val lookup : qsymbol -> env -> path * t + val lookup_opt : qsymbol -> env -> (path * t) option + val lookup_path : qsymbol -> env -> path end (* -------------------------------------------------------------------- *) -module TypeClass : sig - type t = typeclass +module TcInstance : sig + type t = tcinstance - val add : path -> env -> env - val graph : env -> EcTypeClass.graph + val add : path -> env -> env + val bind : ?import:bool -> symbol option -> t -> env -> env val by_path : path -> env -> t val by_path_opt : path -> env -> t option @@ -400,8 +423,7 @@ module TypeClass : sig val lookup_opt : qsymbol -> env -> (path * t) option val lookup_path : qsymbol -> env -> path - val add_instance : ?import:bool -> (ty_params * ty) -> tcinstance -> is_local -> env -> env - val get_instances : env -> ((ty_params * ty) * tcinstance) list + val get_all : env -> (path option * t) list end (* -------------------------------------------------------------------- *) diff --git a/src/ecFol.ml b/src/ecFol.ml index 7a9fbf4942..8a97a1e6b4 100644 --- a/src/ecFol.ml +++ b/src/ecFol.ml @@ -191,8 +191,7 @@ let f_mu_x f1 f2 = let proj_distr_ty env ty = match (EcEnv.Ty.hnorm ty env).ty_node with - | Tconstr(_,lty) when List.length lty = 1 -> - List.hd lty + | Tconstr(_, [lty, []]) -> lty | _ -> assert false let f_mu env f1 f2 = @@ -854,7 +853,7 @@ type sform = | SFimp of form * form | SFiff of form * form | SFeq of form * form - | SFop of (EcPath.path * ty list) * (form list) + | SFop of (EcPath.path * etyarg list) * (form list) | SFhoareF of sHoareF | SFhoareS of sHoareS diff --git a/src/ecFol.mli b/src/ecFol.mli index 6be1d1aafc..787c877f38 100644 --- a/src/ecFol.mli +++ b/src/ecFol.mli @@ -226,7 +226,7 @@ type sform = | SFimp of form * form | SFiff of form * form | SFeq of form * form - | SFop of (path * ty list) * (form list) + | SFop of (path * etyarg list) * (form list) | SFhoareF of sHoareF | SFhoareS of sHoareS diff --git a/src/ecHiGoal.ml b/src/ecHiGoal.ml index 93389a275f..65bf29f180 100644 --- a/src/ecHiGoal.ml +++ b/src/ecHiGoal.ml @@ -114,15 +114,16 @@ let process_simplify_info ri (tc : tcenv1) = in { - EcReduction.beta = ri.pbeta; - EcReduction.delta_p = delta_p; - EcReduction.delta_h = delta_h; - EcReduction.zeta = ri.pzeta; - EcReduction.iota = ri.piota; - EcReduction.eta = ri.peta; - EcReduction.logic = if ri.plogic then Some `Full else None; - EcReduction.modpath = ri.pmodpath; - EcReduction.user = ri.puser; + EcReduction.beta = ri.pbeta; + EcReduction.delta_p = delta_p; + EcReduction.delta_h = delta_h; + EcReduction.delta_tc = ri.pdeltatc; + EcReduction.zeta = ri.pzeta; + EcReduction.iota = ri.piota; + EcReduction.eta = ri.peta; + EcReduction.logic = if ri.plogic then Some `Full else None; + EcReduction.modpath = ri.pmodpath; + EcReduction.user = ri.puser; } (*-------------------------------------------------------------------- *) @@ -571,7 +572,7 @@ let process_exacttype qs (tc : tcenv1) = tc_error !!tc "%a" EcEnv.pp_lookup_failure cause in let tys = - List.map (fun a -> EcTypes.tvar a) + List.map (fun (a, _) -> (tvar a, [])) (EcEnv.LDecl.tohyps hyps).h_tvar in let pt = ptglobal ~tys p in @@ -753,8 +754,10 @@ let process_delta ~und_delta ?target ((s :rwside), o, p) tc = in - let ri = { EcReduction.full_red with - delta_p = (fun p -> if Some p = dp then `Force else `IfTransparent)} in + let ri = + let delta_p p = + if Some p = dp then `Force else `IfTransparent + in { EcReduction.full_red with delta_p } in let na = List.length args in match s with @@ -791,9 +794,12 @@ let process_delta ~und_delta ?target ((s :rwside), o, p) tc = match sform_of_form fp with | SFop ((_, tvi), []) -> begin - (* FIXME: TC HOOK *) - let body = Tvar.f_subst ~freshen:true tparams tvi body in - let body = f_app body args topfp.f_ty in + let body = + Tvar.f_subst + ~freshen:true + (List.combine (List.map fst tparams) tvi) + body in + let body = f_app body args topfp.f_ty in try EcReduction.h_red EcReduction.beta_red hyps body with EcEnv.NotReducible -> body end @@ -814,9 +820,13 @@ let process_delta ~und_delta ?target ((s :rwside), o, p) tc = | `RtoL -> let fp = - (* FIXME: TC HOOK *) - let body = Tvar.f_subst ~freshen:true tparams tvi body in - let fp = f_app body args p.f_ty in + let body = + Tvar.f_subst + ~freshen:true + (List.combine (List.map fst tparams) tvi) + body + in + let fp = f_app body args p.f_ty in try EcReduction.h_red EcReduction.beta_red hyps fp with EcEnv.NotReducible -> fp in @@ -1542,7 +1552,10 @@ let rec process_mintros_1 ?(cf = true) ttenv pis gs = end in - let tc = t_ors [t_elimT_ind `Case; t_elim; t_elim_prind `Case] in + let tc = t_ors [ + t_elimT_ind ~reduce:`Full `Case; + t_elim ~reduce:`Full; + t_elim_prind ~reduce:`Full `Case] in let tc = fun g -> try tc g @@ -2176,7 +2189,11 @@ let process_split ?(i : int option) (tc : tcenv1) = let process_elim (pe, qs) tc = let doelim tc = match qs with - | None -> t_or (t_elimT_ind `Ind) t_elim tc + | None -> + t_or + (t_elimT_ind ~reduce:`Full `Ind) + (t_elim ~reduce:`Full) + tc | Some qs -> let qs = { fp_mode = `Implicit; @@ -2222,7 +2239,10 @@ let process_case ?(doeq = false) gp tc = with E.LEMFailure -> try FApi.t_last - (t_ors [t_elimT_ind `Case; t_elim; t_elim_prind `Case]) + (t_ors [ + t_elimT_ind ~reduce:`Full `Case; + t_elim ~reduce:`Full; + t_elim_prind ~reduce:`Full `Case]) (process_move ~doeq gp.pr_view gp.pr_rev tc) with EcCoreGoal.InvalidGoalShape -> diff --git a/src/ecHiInductive.ml b/src/ecHiInductive.ml index 9084f1118a..9a5dae1122 100644 --- a/src/ecHiInductive.ml +++ b/src/ecHiInductive.ml @@ -84,8 +84,10 @@ let trans_datatype (env : EcEnv.env) (name : ptydname) (dt : pdatatype) = let env0 = let myself = { tyd_params = EcUnify.UniEnv.tparams ue; - tyd_type = Abstract; + tyd_type = `Abstract []; + tyd_resolve = true; tyd_loca = lc; + tyd_subtype = None; } in EcEnv.Ty.bind (unloc name) myself env in @@ -131,19 +133,19 @@ let trans_datatype (env : EcEnv.env) (name : ptydname) (dt : pdatatype) = let tdecl = EcEnv.Ty.by_path_opt tname env0 |> odfl (EcDecl.abs_tydecl ~params:(`Named tparams) lc) in - let tyinst = ty_instantiate tdecl.tyd_params targs in + let tyinst = ty_instanciate tdecl.tyd_params targs in match tdecl.tyd_type with - | Abstract -> - List.exists isempty targs + | `Abstract _ -> + List.exists isempty (List.fst targs) - | Concrete ty -> + | `Concrete ty -> isempty_1 [ tyinst ty ] - | Record (_, fields) -> + | `Record (_, fields) -> isempty_1 (List.map (tyinst -| snd) fields) - | Datatype dt -> + | `Datatype dt -> (* FIXME: Inspecting all constructors recursively causes non-termination in some cases. One can have the same limitation as is done for positivity in order to limit this @@ -333,7 +335,7 @@ let trans_matchfix | PPApp ((cname, tvi), _cargs) -> let tvi = tvi |> omap (TT.transtvi env ue) in let filter = fun _ op -> EcDecl.is_ctor op in - let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in + let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in match cts with | [] -> fxerror cname.pl_loc env TT.FXE_CtorUnk @@ -368,7 +370,7 @@ let trans_matchfix let indp, _ = Msym.find x indtbl in let indty = oget (EcEnv.Ty.by_path_opt indp env) in let ind = (oget (EcDecl.tydecl_as_datatype indty)).tydt_ctors in - let codom = tconstr indp (List.map tvar indty.tyd_params) in + let codom = tconstr_tc indp (etyargs_of_tparams indty.tyd_params) in let tys = List.map (fun (_, dom) -> toarrow dom codom) ind in let tys, _ = EcUnify.UniEnv.opentys ue indty.tyd_params None tys in let doargs cty = @@ -380,7 +382,7 @@ let trans_matchfix | PPApp ((cname, tvi), cargs) -> let filter = fun _ op -> EcDecl.is_ctor op in let tvi = tvi |> omap (TT.transtvi env ue) in - let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in + let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in match cts with | [] -> @@ -411,8 +413,8 @@ let trans_matchfix EcUnify.UniEnv.restore ~src:subue ~dst:ue; let ctorty = - let tvi = Some (EcUnify.TVIunamed tvi) in - fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in + let tvi = Some (EcUnify.tvi_unamed tvi) in + fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in let pty = EcUnify.UniEnv.fresh ue in (try EcUnify.unify env ue (toarrow ctorty pty) opty @@ -483,7 +485,7 @@ let trans_matchfix let codom = ty_subst ts codom in let opexpr = EcPath.pqname (EcEnv.root env) name in let args = List.map (snd_map (ty_subst ts)) args in - let opexpr = e_op opexpr (List.map tvar tparams) + let opexpr = e_op_tc opexpr (etyargs_of_tparams tparams) (toarrow (List.map snd args) codom) in let ebsubst = bind_elocal ts opname opexpr diff --git a/src/ecHiNotations.ml b/src/ecHiNotations.ml index 3d742857c5..6ed6ac979e 100644 --- a/src/ecHiNotations.ml +++ b/src/ecHiNotations.ml @@ -12,7 +12,7 @@ module TT = EcTyping (* -------------------------------------------------------------------- *) type nterror = | NTE_Typing of EcTyping.tyerror -| NTE_TyNotClosed +| NTE_TyNotClosed of EcUnify.uniflags | NTE_DupIdent | NTE_UnknownBinder of symbol | NTE_AbbrevIsVar @@ -62,8 +62,8 @@ let trans_notation_r (env : env) (nt : pnotation located) = let codom = TT.transty TT.tp_relax env ue nt.nt_codom in let body = TT.transexpcast benv `InOp ue codom nt.nt_body in - if not (EcUnify.UniEnv.closed ue) then - nterror gloc env NTE_TyNotClosed; + Option.iter (fun infos -> nterror gloc env (NTE_TyNotClosed infos)) + @@ EcUnify.UniEnv.xclosed ue; ignore body; () @@ -80,11 +80,13 @@ let trans_abbrev_r (env : env) (at : pabbrev located) = let codom = TT.transty TT.tp_relax env ue (fst at.ab_def) in let body = TT.transexpcast benv `InOp ue codom (snd at.ab_def) in - if not (EcUnify.UniEnv.closed ue) then - nterror gloc env NTE_TyNotClosed; + Option.iter (fun infos -> nterror gloc env (NTE_TyNotClosed infos)) + @@ EcUnify.UniEnv.xclosed ue; - let ts = Tuni.subst (EcUnify.UniEnv.close ue) in - let es = e_subst ts in + let ts = Tuni.subst + ~tw_uni:(EcUnify.UniEnv.tw_assubst ue) + (EcUnify.UniEnv.close ue) in + let es = e_subst ts in let body = es body in let codom = ty_subst ts codom in let xs = List.map (snd_map (ty_subst ts)) xs in diff --git a/src/ecHiNotations.mli b/src/ecHiNotations.mli index 54dd54543e..53aa868c15 100644 --- a/src/ecHiNotations.mli +++ b/src/ecHiNotations.mli @@ -8,7 +8,7 @@ open EcEnv (* -------------------------------------------------------------------- *) type nterror = | NTE_Typing of EcTyping.tyerror -| NTE_TyNotClosed +| NTE_TyNotClosed of EcUnify.uniflags | NTE_DupIdent | NTE_UnknownBinder of symbol | NTE_AbbrevIsVar diff --git a/src/ecHiPredicates.ml b/src/ecHiPredicates.ml index 49e725ad58..9fba05c55b 100644 --- a/src/ecHiPredicates.ml +++ b/src/ecHiPredicates.ml @@ -2,7 +2,6 @@ open EcUtils open EcSymbols open EcLocation -open EcTypes open EcCoreSubst open EcParsetree open EcDecl @@ -11,8 +10,8 @@ module TT = EcTyping (* -------------------------------------------------------------------- *) type tperror = -| TPE_Typing of EcTyping.tyerror -| TPE_TyNotClosed +| TPE_Typing of EcTyping.tyerror +| TPE_TyNotClosed of EcUnify.uniflags | TPE_DuplicatedConstr of symbol exception TransPredError of EcLocation.t * EcEnv.env * tperror @@ -20,8 +19,8 @@ exception TransPredError of EcLocation.t * EcEnv.env * tperror let tperror loc env e = raise (TransPredError (loc, env, e)) (* -------------------------------------------------------------------- *) -let close_pr_body (uni : ty EcUid.Muid.t) (body : prbody) = - let fsubst = EcFol.Fsubst.f_subst_init ~tu:uni () in +let close_pr_body (uidmap : EcTypes.ty EcAst.TyUni.Muid.t) (body : prbody) = + let fsubst = EcFol.Fsubst.f_subst_init ~tu:uidmap () in let tsubst = ty_subst fsubst in match body with @@ -74,13 +73,13 @@ let trans_preddecl_r (env : EcEnv.env) (pr : ppredicate located) = in - if not (EcUnify.UniEnv.closed ue) then - tperror loc env TPE_TyNotClosed; + Option.iter + (fun infos -> tperror loc env (TPE_TyNotClosed infos)) + (EcUnify.UniEnv.xclosed ue); - let uidmap = EcUnify.UniEnv.assubst ue in + let uidmap = EcUnify.UniEnv.assubst ue in let tparams = EcUnify.UniEnv.tparams ue in let body = body |> omap (close_pr_body uidmap) in - let dom = Tuni.subst_dom uidmap dom in EcDecl.mk_pred ~opaque:optransparent tparams dom body pr.pp_locality diff --git a/src/ecHiPredicates.mli b/src/ecHiPredicates.mli index eb56da6628..f411802cce 100644 --- a/src/ecHiPredicates.mli +++ b/src/ecHiPredicates.mli @@ -5,8 +5,8 @@ open EcParsetree (* -------------------------------------------------------------------- *) type tperror = -| TPE_Typing of EcTyping.tyerror -| TPE_TyNotClosed +| TPE_Typing of EcTyping.tyerror +| TPE_TyNotClosed of EcUnify.uniflags | TPE_DuplicatedConstr of symbol exception TransPredError of EcLocation.t * EcEnv.env * tperror diff --git a/src/ecIdent.ml b/src/ecIdent.ml index 9487c0e42c..3968b34c69 100644 --- a/src/ecIdent.ml +++ b/src/ecIdent.ml @@ -57,3 +57,4 @@ let tostring_internal (id : t) = (* -------------------------------------------------------------------- *) let pp_ident fmt id = Format.fprintf fmt "%s" (name id) +let pp = pp_ident diff --git a/src/ecIdent.mli b/src/ecIdent.mli index 82942edb77..659a7e474c 100644 --- a/src/ecIdent.mli +++ b/src/ecIdent.mli @@ -38,3 +38,4 @@ val fv_add : ident -> int Mid.t -> int Mid.t (* -------------------------------------------------------------------- *) val pp_ident : Format.formatter -> t -> unit +val pp : Format.formatter -> t -> unit diff --git a/src/ecInductive.ml b/src/ecInductive.ml index 81f3be80d8..eb146ee9a0 100644 --- a/src/ecInductive.ml +++ b/src/ecInductive.ml @@ -38,15 +38,15 @@ let datatype_proj_path (p : EP.path) (x : symbol) = (* -------------------------------------------------------------------- *) let indsc_of_record (rc : record) = - let targs = List.map tvar rc.rc_tparams in - let recty = tconstr rc.rc_path targs in + let targs = etyargs_of_tparams rc.rc_tparams in + let recty = tconstr_tc rc.rc_path targs in let recx = fresh_id_of_ty recty in let recfm = FL.f_local recx recty in let predty = tfun recty tbool in let predx = EcIdent.create "P" in let pred = FL.f_local predx predty in let ctor = record_ctor_path rc.rc_path in - let ctor = FL.f_op ctor targs (toarrow (List.map snd rc.rc_fields) recty) in + let ctor = FL.f_op_tc ctor targs (toarrow (List.map snd rc.rc_fields) recty) in let prem = let ids = List.map (fun (_, fty) -> (fresh_id_of_ty fty, fty)) rc.rc_fields in let vars = List.map (fun (x, xty) -> FL.f_local x xty) ids in @@ -129,7 +129,7 @@ let rec occurs ?(normty = identity) p t = (** Tests whether the first list is a list of type variables, matching the identifiers of the second list. *) let ty_params_compat = - List.for_all2 (fun ty param_id -> + List.for_all2 (fun (ty, _) (param_id, _) -> match ty.ty_node with | Tvar id -> EcIdent.id_equal id param_id | _ -> false) @@ -142,13 +142,13 @@ let rec check_positivity_in_decl fct p decl ident = and iter l f = List.iter f l in match decl.tyd_type with - | Concrete ty -> with_context ~ident p Concrete (check ty) - | Abstract -> non_positive p AbstractTypeRestriction - | Datatype { tydt_ctors } -> + | `Concrete ty -> with_context ~ident p Concrete (check ty) + | `Abstract _ -> non_positive p AbstractTypeRestriction + | `Datatype { tydt_ctors; _ } -> iter tydt_ctors @@ fun (name, argty) -> iter argty @@ fun ty -> with_context ~ident p (Variant name) (check ty) - | Record (_, tys) -> + | `Record (_, tys) -> iter tys @@ fun (name, ty) -> with_context ~ident p (Record name) (check ty) @@ -162,9 +162,9 @@ and check_positivity_ident fct p params ident ty = non_positive p (TypePositionRestriction ty) | Tconstr (q, args) -> let decl = fct q in - List.iter (check_positivity_ident fct p params ident) args; + List.iter (fun (a, _) -> check_positivity_ident fct p params ident a) args; List.combine args decl.tyd_params - |> List.filter_map (fun (arg, ident') -> + |> List.filter_map (fun ((arg, _), (ident', _)) -> if EcTypes.var_mem ident arg then Some ident' else None) |> List.iter (check_positivity_in_decl fct q decl) | Tfun (from, to_) -> @@ -177,12 +177,12 @@ let rec check_positivity_path fct p ty = | Tglob _ | Tunivar _ | Tvar _ -> () | Ttuple tys -> List.iter (check_positivity_path fct p) tys | Tconstr (q, args) when EcPath.p_equal q p -> - if List.exists (occurs p) args then non_positive p (NonPositiveOcc ty) + if List.exists (fun (a, _) -> occurs p a) args then non_positive p (NonPositiveOcc ty) | Tconstr (q, args) -> let decl = fct q in - List.iter (check_positivity_path fct p) args; + List.iter (fun (a, _) -> check_positivity_path fct p a) args; List.combine args decl.tyd_params - |> List.filter_map (fun (arg, ident) -> + |> List.filter_map (fun ((arg, _), (ident, _)) -> if occurs p arg then Some ident else None) |> List.iter (check_positivity_in_decl fct q decl) | Tfun (from, to_) -> @@ -223,11 +223,11 @@ let indsc_of_datatype ?(normty = identity) (mode : indmode) (dt : datatype) = |> omap (FL.f_forall [x, GTty ty1]) and schemec mode (targs, p) pred (ctor, tys) = - let indty = tconstr p (List.map tvar targs) in + let indty = tconstr_tc p targs in let xs = List.map (fun xty -> (fresh_id_of_ty xty, xty)) tys in let cargs = List.map (fun (x, xty) -> FL.f_local x xty) xs in let ctor = EcPath.pqoname (EcPath.prefix tpath) ctor in - let ctor = FL.f_op ctor (List.map tvar targs) (toarrow tys indty) in + let ctor = FL.f_op_tc ctor targs (toarrow tys indty) in let form = FL.f_app pred [FL.f_app ctor cargs indty] tbool in let form = match mode with @@ -247,7 +247,7 @@ let indsc_of_datatype ?(normty = identity) (mode : indmode) (dt : datatype) = form and scheme mode (targs, p) ctors = - let indty = tconstr p (List.map tvar targs) in + let indty = tconstr_tc p targs in let indx = fresh_id_of_ty indty in let indfm = FL.f_local indx indty in let predty = tfun indty tbool in @@ -260,11 +260,11 @@ let indsc_of_datatype ?(normty = identity) (mode : indmode) (dt : datatype) = let form = FL.f_forall [predx, GTty predty] form in form - in scheme mode (dt.dt_tparams, tpath) dt.dt_ctors + in scheme mode (etyargs_of_tparams dt.dt_tparams, tpath) dt.dt_ctors (* -------------------------------------------------------------------- *) let datatype_projectors (tpath, tparams, { tydt_ctors = ctors }) = - let thety = tconstr tpath (List.map tvar tparams) in + let thety = tconstr_tc tpath (etyargs_of_tparams tparams) in let do1 i (cname, cty) = let thv = EcIdent.create "the" in @@ -378,7 +378,7 @@ let indsc_of_prind ({ ip_path = p; ip_prind = pri } as pr) = FL.f_forall ctor.prc_bds px in - let sc = FL.f_op p (List.map tvar pr.ip_tparams) prty in + let sc = FL.f_op_tc p (etyargs_of_tparams pr.ip_tparams) prty in let sc = FL.f_imp (FL.f_app sc prag tbool) pred in let sc = FL.f_imps (List.map for1 pri.pri_ctors) sc in let sc = FL.f_forall [predx, FL.gtty tbool] sc in @@ -391,7 +391,7 @@ let introsc_of_prind ({ ip_path = p; ip_prind = pri } as pr) = let bds = List.map (snd_map FL.gtty) pri.pri_args in let clty = toarrow (List.map snd pri.pri_args) tbool in let clag = (List.map (curry FL.f_local) pri.pri_args) in - let cl = FL.f_op p (List.map tvar pr.ip_tparams) clty in + let cl = FL.f_op_tc p (etyargs_of_tparams pr.ip_tparams) clty in let cl = FL.f_app cl clag tbool in let for1 ctor = diff --git a/src/ecLexer.mll b/src/ecLexer.mll index 704b0e9764..0ca9d885d0 100644 --- a/src/ecLexer.mll +++ b/src/ecLexer.mll @@ -199,6 +199,7 @@ "theory" , THEORY ; (* KW: global *) "abstract" , ABSTRACT ; (* KW: global *) "section" , SECTION ; (* KW: global *) + "class" , CLASS ; (* KW: global *) "subtype" , SUBTYPE ; (* KW: global *) "type" , TYPE ; (* KW: global *) "instance" , INSTANCE ; (* KW: global *) diff --git a/src/ecLowGoal.ml b/src/ecLowGoal.ml index ae5a28eb0f..ba2464cb7f 100644 --- a/src/ecLowGoal.ml +++ b/src/ecLowGoal.ml @@ -168,7 +168,7 @@ module LowApply = struct | PTGlobal (p, tys) -> (* FIXME: poor API ==> poor error recovery *) let env = LDecl.toenv (hyps_of_ckenv tc) in - (pt, EcEnv.Ax.instantiate p tys env, subgoals) + (pt, EcEnv.Ax.instanciate p tys env, subgoals) | PTTerm pt -> let pt, ax, subgoals = check_ `Elim pt subgoals tc in @@ -743,9 +743,14 @@ let t_apply_hyp (x : EcIdent.t) ?args ?sk tc = let t_hyp (x : EcIdent.t) tc = t_apply_hyp x ~args:[] ~sk:0 tc +(* -------------------------------------------------------------------- *) +let t_apply_s_tc (p : path) (etys : etyarg list) ?args ?sk tc = + tt_apply_s p etys ?args ?sk (FApi.tcenv_of_tcenv1 tc) + (* -------------------------------------------------------------------- *) let t_apply_s (p : path) (tys : ty list) ?args ?sk tc = - tt_apply_s p tys ?args ?sk (FApi.tcenv_of_tcenv1 tc) + let etys = List.map (fun ty -> (ty, [])) tys in + tt_apply_s p etys ?args ?sk (FApi.tcenv_of_tcenv1 tc) (* -------------------------------------------------------------------- *) let t_apply_hd (hd : handle) ?args ?sk tc = @@ -1009,7 +1014,7 @@ let t_true (tc : tcenv1) = let t_reflex_s (f : form) (tc : tcenv1) = t_apply_s LG.p_eq_refl [f.f_ty] ~args:[f] tc -let t_reflex ?(mode=`Conv) ?reduce (tc : tcenv1) = +let t_reflex ?(mode = `Conv) ?reduce (tc : tcenv1) = let t_reflex_r (fp : form) (tc : tcenv1) = match sform_of_form fp with | SFeq (f1, f2) -> @@ -1171,9 +1176,9 @@ let t_elim_r ?(reduce = (`Full : lazyred)) txs tc = | None -> begin let strategy = match reduce with - | `None -> raise InvalidGoalShape - | `Full -> EcReduction.full_red - | `NoDelta -> EcReduction.nodelta in + | `None -> raise InvalidGoalShape + | `Full -> EcReduction.full_red + | `NoDelta -> EcReduction.nodelta in match h_red_opt strategy (FApi.tc1_hyps tc) f1 with | None -> raise InvalidGoalShape @@ -1508,9 +1513,9 @@ let t_elim_prind_r ?reduce ?accept (_mode : [`Case | `Ind]) tc = end; (oget (EcEnv.Op.scheme_of_prind env `Case p), tv, args) - | _ -> raise InvalidGoalShape + | _ -> raise InvalidGoalShape in - in t_apply_s p tv ~args:(args @ [f2]) ~sk tc + t_apply_s_tc p tv ~args:(args @ [f2]) ~sk tc | _ -> raise TTC.NoMatch @@ -1665,7 +1670,7 @@ let t_split_prind ?reduce (tc : tcenv1) = | None -> raise InvalidGoalShape | Some (x, sk) -> let p = EcInductive.prind_introsc_path p x in - t_apply_s p tv ~args ~sk tc + t_apply_s_tc p tv ~args ~sk tc in t_lazy_match ?reduce t_split_r tc @@ -1685,10 +1690,10 @@ let t_or_intro_prind ?reduce (side : side) (tc : tcenv1) = match EcInductive.prind_is_iso_ors pri with | Some ((x, sk), _) when side = `Left -> let p = EcInductive.prind_introsc_path p x in - t_apply_s p tv ~args ~sk tc + t_apply_s_tc p tv ~args ~sk tc | Some (_, (x, sk)) when side = `Right -> let p = EcInductive.prind_introsc_path p x in - t_apply_s p tv ~args ~sk tc + t_apply_s_tc p tv ~args ~sk tc | _ -> raise InvalidGoalShape in t_lazy_match ?reduce t_split_r tc @@ -1889,9 +1894,14 @@ module LowSubst = struct (* check if x is a declared module *) let fv = Sid.add x fv in if EcEnv.Mod.by_mpath_opt (EcPath.mident x) env <> None then fv + (* [f.f_fv] also collects type-variables (which live in + [h_tvar], not [h_local]) and other non-hypothesis idents; + a raw [LDecl.by_id] would crash with [LookupError]. Only + expand let-bound locals. *) else match LDecl.by_id x hyps with | LD_var (_, Some f) -> add_f fv f | _ -> fv + | exception LDecl.LdeclError _ -> fv and add_f fv f = Mid.fold_left add fv f.f_fv in Some(side,v,f, add_f Sid.empty f) @@ -2288,8 +2298,7 @@ let t_progress ?options ?ti (tt : FApi.backward) (tc : tcenv1) = else elims in - let reduce = - if options.pgo_delta.pgod_case then `Full else `NoDelta in + let reduce = if options.pgo_delta.pgod_case then `Full else `NoDelta in FApi.t_switch ~on:`All (t_elim_r ~reduce elims) ~ifok:aux0 ~iffail tc end diff --git a/src/ecLowGoal.mli b/src/ecLowGoal.mli index c17b4e4b28..a9b806d429 100644 --- a/src/ecLowGoal.mli +++ b/src/ecLowGoal.mli @@ -18,7 +18,6 @@ exception InvalidProofTerm (* invalid proof term *) type side = [`Left|`Right] type lazyred = [`Full | `NoDelta | `None] - (* -------------------------------------------------------------------- *) val (@!) : FApi.backward -> FApi.backward -> FApi.backward val (@+) : FApi.backward -> FApi.backward list -> FApi.backward @@ -113,6 +112,8 @@ val t_apply : ?cutsolver:cutsolver -> proofterm -> FApi.backward * skip before applying [p]. *) val t_apply_s : path -> ty list -> ?args:(form list) -> ?sk:int -> FApi.backward +val t_apply_s_tc : path -> etyarg list -> ?args:(form list) -> ?sk:int -> FApi.backward + (* Apply a proof term of the form [h f1...fp _ ... _] constructed from * the local hypothesis and formulas given to the function. The [int] * argument gives the number of premises to skip before applying @@ -191,7 +192,7 @@ val t_elim_iso_or : ?reduce:lazyred -> tcenv1 -> int list * tcenv (* Elimination using an custom elimination principle. *) val t_elimT_form : proofterm -> ?sk:int -> form -> FApi.backward -val t_elimT_form_global : path -> ?typ:(ty list) -> ?sk:int -> form -> FApi.backward +val t_elimT_form_global : path -> ?typ:(etyarg list) -> ?sk:int -> form -> FApi.backward (* Eliminiation using an elimation principle of an induction type *) val t_elimT_ind : ?reduce:lazyred -> [ `Case | `Ind ] -> FApi.backward diff --git a/src/ecMatching.ml b/src/ecMatching.ml index 29a4617cb0..b9dae924c5 100644 --- a/src/ecMatching.ml +++ b/src/ecMatching.ml @@ -389,10 +389,10 @@ module Position = struct let (env, s), npath = normalize_cpos_path env cpath s in (env, s), (npath, normalize_cpos1 env cp1 s) - let resolve_offset1_from_cpos1 env (base: nm_codepos1) (off: codeoffset1) (s: stmt) : nm_codepos1 = + let resolve_offset1_from_cpos1 env (base: nm_codepos1) (off: codeoffset1) (s: stmt) : nm_codepos1 = match off with - | `Absolute off -> normalize_cpos1 env off s - | `Relative i -> + | `Absolute off -> normalize_cpos1 env off s + | `Relative i -> let nm = (base + i) in check_nm_cpos1 nm s; nm @@ -828,7 +828,9 @@ module MEV = struct v let assubst ue ev env = - let subst = f_subst_init ~tu:(EcUnify.UniEnv.assubst ue) () in + let subst = f_subst_init + ~tu:(EcUnify.UniEnv.assubst ue) + ~tw_uni:(EcUnify.UniEnv.tw_assubst ue) () in let subst = EV.fold (fun x m s -> Fsubst.f_bind_mem s x m) ev.evm_mem subst in let subst = EV.fold (fun x mp s -> EcFol.f_bind_mod s x mp env) ev.evm_mod subst in let seen = ref Sid.empty in @@ -1048,7 +1050,7 @@ let f_match_core opts hyps (ue, ev) f1 f2 = | Fop (op1, tys1), Fop (op2, tys2) -> begin if not (EcPath.p_equal op1 op2) then failure (); - try List.iter2 (EcUnify.unify env ue) tys1 tys2 + try List.iter2 (EcUnify.unify_etyarg env ue) tys1 tys2 with EcUnify.UnificationFailure _ -> failure () end @@ -1144,6 +1146,55 @@ let f_match_core opts hyps (ue, ev) f1 f2 = failure (); doit env (subst, mxs) f1' f2' in + (* Eta-reduce a [fun (x_1 ... x_n) => h x_1 ... x_n] body when + [h] does not mention any [x_i]. Returns [Some h] on success. *) + let try_eta_reduce (f : form) : form option = + match f.f_node with + | Fquant (Llambda, bd, body) -> begin + let nbd = List.length bd in + match destr_app body with + | (h, args) when List.length args >= nbd -> + let n_extra = List.length args - nbd in + let extra, tail = List.split_at n_extra args in + let bd_ids = List.map fst bd in + (* Tail must be exactly [x_1; ...; x_n] in order. *) + let tail_ok = + List.for_all2 (fun (x, _) a -> + match a.f_node with + | Flocal y -> EcIdent.id_equal x y + | _ -> false) bd tail in + (* And [h] (with extras) must not mention the [x_i]. *) + let captures = + List.exists (fun id -> Mid.mem id h.f_fv) bd_ids + || List.exists + (fun a -> List.exists (fun id -> Mid.mem id a.f_fv) bd_ids) + extra in + if tail_ok && not captures then + Some (if n_extra = 0 then h else f_app h extra body.f_ty) + else None + | _ -> None + end + | _ -> None in + + let is_lambda f = + match f.f_node with Fquant (Llambda, _, _) -> true | _ -> false in + let try_etared () = + (* Only η-reduce when the other side is not itself a lambda; + if both are lambdas, the structural Fquant/Fquant case + handles it, and prematurely eta-reducing one side would + interfere with higher-order matching against lambda + patterns. *) + match f1.f_node, f2.f_node with + | Fquant (Llambda, _, _), _ when not (is_lambda f2) -> + (match try_eta_reduce f1 with + | Some f1' -> doit env (subst, mxs) f1' f2 + | None -> failure ()) + | _, Fquant (Llambda, _, _) when not (is_lambda f1) -> + (match try_eta_reduce f2 with + | Some f2' -> doit env (subst, mxs) f1 f2' + | None -> failure ()) + | _ -> failure () in + let try_horder () = if not opts.fm_horder then failure (); @@ -1166,7 +1217,16 @@ let f_match_core opts hyps (ue, ev) f1 f2 = let try_delta () = if not opts.fm_delta then failure (); - + (* Drain pending TC constraints before checking [tc_reducible]: + a [TCIUni] witness on a TC op-head needs to be committed in + the resolution map (and then dereferenced via [norm]) for + [tc_core_reduce] to fire. Without this drain, a parametric- + carrier proof-term carrying an unresolved [TCIUni] would + fail to reduce here even when the carrier's TC instance is + registered in the env. *) + EcUnify.UniEnv.flush_tc_problems env ue; + let f1 = norm f1 in + let f2 = norm f2 in match fst_map f_node (destr_app f1), fst_map f_node (destr_app f2) with @@ -1182,6 +1242,12 @@ let f_match_core opts hyps (ue, ev) f1 f2 = | _, (Fop (op2, tys2), args2) when EcEnv.Op.reducible env op2 -> doit_reduce env (doit env ilc f1) f2.f_ty op2 tys2 args2 + | (Fop (op1, tys1), args1), _ when EcEnv.Op.tc_reducible env op1 tys1 -> + doit_tc_reduce env ((doit env ilc)^~ f2) f1.f_ty op1 tys1 args1 + + | _, (Fop (op2, tys2), args2) when EcEnv.Op.tc_reducible env op2 tys2 -> + doit_tc_reduce env (doit env ilc f1) f2.f_ty op2 tys2 args2 + | _, _ -> failure () in @@ -1193,7 +1259,7 @@ let f_match_core opts hyps (ue, ev) f1 f2 = List.find_map_opt (fun doit -> try Some (doit ()) with MatchFailure -> None) - [try_betared; try_horder; try_delta; default] + [try_betared; try_horder; try_etared; try_delta; default] |> oget ~exn:MatchFailure and doit_args env ilc fs1 fs2 = @@ -1207,6 +1273,12 @@ let f_match_core opts hyps (ue, ev) f1 f2 = with NotReducible -> raise MatchFailure in cb (odfl reduced (EcReduction.h_red_opt EcReduction.beta_red hyps reduced)) + and doit_tc_reduce env cb ty op tys args = + let reduced = + try f_app (EcEnv.Op.tc_reduce env op tys) args ty + with NotReducible -> raise MatchFailure in + cb (odfl reduced (EcReduction.h_red_opt EcReduction.beta_red hyps reduced)) + and doit_lreduce _env cb ty x args = let reduced = try f_app (LDecl.unfold x hyps) args ty @@ -1292,7 +1364,7 @@ let f_match opts hyps (ue, ev) f1 f2 = raise MatchFailure; let clue = try EcUnify.UniEnv.close ue - with EcUnify.UninstantiateUni -> raise MatchFailure + with EcUnify.UninstanciateUni _ -> raise MatchFailure in (ue, clue, ev) diff --git a/src/ecMatching.mli b/src/ecMatching.mli index d13622e4d7..cf01bb8cd0 100644 --- a/src/ecMatching.mli +++ b/src/ecMatching.mli @@ -1,6 +1,5 @@ (* -------------------------------------------------------------------- *) open EcMaps -open EcUid open EcIdent open EcTypes open EcModules @@ -384,7 +383,7 @@ val f_match : -> unienv * mevmap -> form -> form - -> unienv * (ty Muid.t) * mevmap + -> unienv * (ty EcAst.TyUni.Muid.t) * mevmap (* -------------------------------------------------------------------- *) type ptnpos = private [`Select of int | `Sub of ptnpos] Mint.t diff --git a/src/ecPV.ml b/src/ecPV.ml index 3d173b9b67..5eff3d8adb 100644 --- a/src/ecPV.ml +++ b/src/ecPV.ml @@ -1011,7 +1011,7 @@ module Mpv2 = struct when EcIdent.id_equal ml m1 && EcIdent.id_equal mr m2 -> add_glob env (EcPath.mident mp1) (EcPath.mident mp2) eqs | Fop(op1,tys1), Fop(op2,tys2) when EcPath.p_equal op1 op2 && - List.all2 (EcReduction.EqTest.for_type env) tys1 tys2 -> eqs + List.all2 (fun (t1, _) (t2, _) -> EcReduction.EqTest.for_type env t1 t2) tys1 tys2 -> eqs | Fapp(f1,a1), Fapp(f2,a2) -> List.fold_left2 (add_eq local) eqs (f1::a1) (f2::a2) | Ftuple es1, Ftuple es2 -> @@ -1110,7 +1110,7 @@ module Mpv2 = struct I postpone this for latter *) | Eop(op1,tys1), Eop(op2,tys2) when EcPath.p_equal op1 op2 && - List.all2 (EcReduction.EqTest.for_type env) tys1 tys2 -> eqs + List.all2 (fun (t1, _) (t2, _) -> EcReduction.EqTest.for_type env t1 t2) tys1 tys2 -> eqs | Eapp(f1,a1), Eapp(f2,a2) -> List.fold_left2 (add_eqs_loc env local) eqs (f1::a1) (f2::a2) | Elet(lp1,a1,b1), Elet(lp2,a2,b2) -> diff --git a/src/ecParser.mly b/src/ecParser.mly index 65cd97b95f..240c30cf90 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -90,17 +90,18 @@ let mk_simplify l = if l = [] then - { pbeta = true; pzeta = true; - piota = true; peta = true; - plogic = true; pdelta = None; - pmodpath = true; puser = true; } + { pbeta = true; pzeta = true; + piota = true; peta = true; + plogic = true; pdelta = None; + pdeltatc = true; pmodpath = true; + puser = true; } else let doarg acc = function | `Delta l -> if l = [] || acc.pdelta = None then { acc with pdelta = None } else { acc with pdelta = Some (oget acc.pdelta @ l) } - + | `DeltaTC -> { acc with pdeltatc = true } | `Zeta -> { acc with pzeta = true } | `Iota -> { acc with piota = true } | `Beta -> { acc with pbeta = true } @@ -110,10 +111,11 @@ | `User -> { acc with puser = true } in List.fold_left doarg - { pbeta = false; pzeta = false; - piota = false; peta = false; - plogic = false; pdelta = Some []; - pmodpath = false; puser = false; } l + { pbeta = false; pzeta = false; + piota = false; peta = false; + plogic = false; pdelta = Some []; + pdeltatc = false; pmodpath = false; + puser = false; } l let simplify_red = [`Zeta; `Iota; `Beta; `Eta; `Logic; `ModPath; `User] @@ -404,6 +406,7 @@ %token CEQ %token CFOLD %token CHANGE +%token CLASS %token CLEAR %token CLONE %token COLON @@ -1635,6 +1638,7 @@ signature_item: pfd_uses = { pmre_name = x; pmre_orcls = orcls; } } } (* -------------------------------------------------------------------- *) +(* EcTypes declarations / definitions *) %inline locality: | (* empty *) { `Global } | LOCAL { `Local } @@ -1649,19 +1653,29 @@ signature_item: %inline is_local: | lc=loc(locality) { locality_as_local lc } -(* -------------------------------------------------------------------- *) -(* EcTypes declarations / definitions *) +tcparam: +| tys=ioption(type_args) x=lqident + { (x, odfl [] tys) } + +tc_parent: +| p=tcparam + { (p, []) } +| LPAREN p=tcparam WITH ren=plist1(tc_rename, COMMA) RPAREN + { (p, ren) } + +tc_rename: +| src=oident EQ tgt=oident { (src, tgt) } typaram: -| x=tident - { (x : ptyparam) } +| x=tident { (x, []) } +| x=tident LTCOLON tc=plist1(tcparam, AMP) { (x, tc) } typarams: | empty { ([] : ptyparams) } | x=tident - { ([x] : ptyparams) } + { ([(x, [])] : ptyparams) } | xs=paren(plist1(typaram, COMMA)) { (xs : ptyparams) } @@ -1685,7 +1699,10 @@ rec_field_def: typedecl: | locality=locality TYPE td=rlist1(tyd_name, COMMA) - { List.map (fun x -> mk_tydecl ~locality x PTYD_Abstract) td } + { List.map (fun x -> mk_tydecl ~locality x (PTYD_Abstract [])) td } + +| locality=locality TYPE td=tyd_name LTCOLON tcs=rlist1(tcparam, AMP) + { [mk_tydecl ~locality td (PTYD_Abstract tcs)] } | locality=locality TYPE td=tyd_name EQ te=loc(type_exp) { [mk_tydecl ~locality td (PTYD_Alias te)] } @@ -1696,6 +1713,29 @@ typedecl: | locality=locality TYPE td=tyd_name EQ te=datatype_def { [mk_tydecl ~locality td (PTYD_Datatype te)] } +(* -------------------------------------------------------------------- *) +(* Type classes *) +typeclass: +| loca=is_local TYPE CLASS tya=tyvars_decl? x=lident + inth=prefix(LTCOLON, plist1(tc_parent, AMP))? + EQ LBRACE body=tc_body RBRACE { + { ptc_name = x; + ptc_params = tya; + ptc_inth = odfl [] inth; + ptc_ops = fst body; + ptc_axs = snd body; + ptc_loca = loca; } + } + +tc_body: +| ops=tc_op* axs=tc_ax* { (ops, axs) } + +tc_op: +| OP x=oident COLON ty=loc(type_exp) { (x, ty) } + +tc_ax: +| AXIOM x=ident COLON ax=form { (x, ax) } + (* -------------------------------------------------------------------- *) (* Subtypes *) subtype: @@ -1714,29 +1754,22 @@ subtype_rename: (* -------------------------------------------------------------------- *) (* Type classes (instances) *) tycinstance: -| loca=is_local INSTANCE x=qident - WITH typ=tyvars_decl? ty=loc(type_exp) ops=tyci_op* axs=tyci_ax* +| loca=is_local INSTANCE tc=tcparam args=tyci_args? + name=prefix(AS, lident)? WITH typ=tyvars_decl? ty=loc(type_exp) ops=tyci_op* axs=tyci_ax* { - { pti_name = x; + let args = args |> omap (fun (c, p) -> `Ring (c, p)) in + { pti_tc = tc; + pti_name = name; pti_type = (odfl [] typ, ty); pti_ops = ops; pti_axs = axs; - pti_args = None; - pti_loca = loca; - } + pti_args = args; + pti_loca = loca; } } -| loca=is_local INSTANCE x=qident c=uoption(UINT) p=uoption(UINT) - WITH typ=tyvars_decl? ty=loc(type_exp) ops=tyci_op* axs=tyci_ax* - { - { pti_name = x; - pti_type = (odfl [] typ, ty); - pti_ops = ops; - pti_axs = axs; - pti_args = Some (`Ring (c, p)); - pti_loca = loca; - } - } +tyci_args: +| c=uoption(UINT) p=uoption(UINT) + { (c, p) } tyci_op: | OP x=oident EQ tg=qoident @@ -1764,8 +1797,9 @@ pred_tydom: tyvars_decl: | LBRACKET tyvars=rlist0(typaram, COMMA) RBRACKET -| LBRACKET tyvars=rlist2(tident, empty) RBRACKET { tyvars } +| LBRACKET tyvars=rlist2(tident, empty) RBRACKET + { List.map (fun x -> (x, [])) tyvars } op_or_const: | OP { `Op } @@ -2492,6 +2526,7 @@ genpattern: simplify_arg: | DELTA l=qoident* { `Delta l } +| CLASS { `DeltaTC } | ZETA { `Zeta } | IOTA { `Iota } | BETA { `Beta } @@ -3933,6 +3968,7 @@ global_action: | sig_def { Ginterface $1 } | typedecl { Gtype $1 } | subtype { Gsubtype $1 } +| typeclass { Gtypeclass $1 } | tycinstance { Gtycinstance $1 } | operator { Goperator $1 } | exception_ { Gexception $1 } diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index fa83d6e698..7d860ccd6d 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -136,19 +136,21 @@ and 'a rfield = { (* -------------------------------------------------------------------- *) type pmodule_type = pqsymbol -type ptyparam = psymbol +(* -------------------------------------------------------------------- *) +type ptcparam = pqsymbol * pty list +type ptyparam = psymbol * ptcparam list type ptyparams = ptyparam list type ptydname = (ptyparams * psymbol) located type ptydecl = { - pty_name : psymbol; - pty_tyvars : ptyparams; - pty_body : ptydbody; + pty_name : psymbol; + pty_tyvars : ptyparams; + pty_body : ptydbody; pty_locality : locality; } and ptydbody = - | PTYD_Abstract + | PTYD_Abstract of ptcparam list | PTYD_Alias of pty | PTYD_Record of precord | PTYD_Datatype of pdatatype @@ -162,7 +164,6 @@ type f_or_mod_ident = | FM_FunOrVar of pgamepath | FM_Mod of pmsymbol located - type pmod_restr_mem_el = | PMPlus of f_or_mod_ident | PMMinus of f_or_mod_ident @@ -172,7 +173,7 @@ type pmod_restr_mem_el = type pmod_restr_mem = pmod_restr_mem_el list (* -------------------------------------------------------------------- *) -type pmemory = psymbol +type pmemory = psymbol type phoarecmp = EcFol.hoarecmp @@ -441,9 +442,6 @@ type psubtype = { } (* -------------------------------------------------------------------- *) -type ptyvardecls = - psymbol list - type pop_def = | PO_abstr of pty | PO_concr of pty * pformula @@ -465,7 +463,7 @@ type poperator = { po_name : psymbol; po_aliases: psymbol list; po_tags : psymbol list; - po_tyvars : ptyvardecls option; + po_tyvars : ptyparams option; po_args : ptybindings * ptybindings option; po_def : pop_def; po_ax : osymbol_r; @@ -499,7 +497,7 @@ and ppind = ptybindings * (ppind_ctor list) type ppredicate = { pp_name : psymbol; - pp_tyvars : psymbol list option; + pp_tyvars : ptyparams option; pp_def : ppred_def; pp_locality : locality; } @@ -507,7 +505,7 @@ type ppredicate = { (* -------------------------------------------------------------------- *) type pnotation = { nt_name : psymbol; - nt_tv : ptyvardecls option; + nt_tv : ptyparams option; nt_bd : (psymbol * pty) list; nt_args : (psymbol * (psymbol list * pty option)) list; nt_codom : pty; @@ -521,7 +519,7 @@ type abrvopts = (bool * abrvopt) list type pabbrev = { ab_name : psymbol; - ab_tv : ptyvardecls option; + ab_tv : ptyparams option; ab_args : ptybindings; ab_def : pty * pexpr; ab_opts : abrvopts; @@ -562,6 +560,7 @@ type pmpred_args = (osymbol * pformula) list type preduction = { pbeta : bool; (* β-reduction *) pdelta : pqsymbol list option; (* definition unfolding *) + pdeltatc : bool; pzeta : bool; (* let-reduction *) piota : bool; (* case/if-reduction *) peta : bool; (* η-reduction *) @@ -1148,8 +1147,18 @@ type prealize = { } (* -------------------------------------------------------------------- *) +type ptypeclass = { + ptc_name : psymbol; + ptc_params : ptyparams option; + ptc_inth : (ptcparam * (psymbol * psymbol) list) list; + ptc_ops : (psymbol * pty) list; + ptc_axs : (psymbol * pformula) list; + ptc_loca : is_local; +} + type ptycinstance = { - pti_name : pqsymbol; + pti_tc : ptcparam; + pti_name : psymbol option; pti_type : ptyparams * pty; pti_ops : (psymbol * (pty list * pqsymbol)) list; pti_axs : (psymbol * ptactic_core) list; @@ -1346,6 +1355,7 @@ type global_action = | Gaxiom of paxiom | Gtype of ptydecl list | Gsubtype of psubtype + | Gtypeclass of ptypeclass | Gtycinstance of ptycinstance | Gaddrw of (is_local * pqsymbol * pqsymbol list) | Greduction of puserred diff --git a/src/ecPath.ml b/src/ecPath.ml index 0cb0edf4da..8cb456c743 100644 --- a/src/ecPath.ml +++ b/src/ecPath.ml @@ -104,6 +104,9 @@ let rec tostring p = | Psymbol x -> x | Pqname (p,x) -> Printf.sprintf "%s.%s" (tostring p) x +let pp_path fmt p = + Format.fprintf fmt "%s" (tostring p) + let tolist = let rec aux l p = match p.p_node with @@ -394,10 +397,16 @@ let rec m_tostring (m : mpath) = in Printf.sprintf "%s%s%s" top args sub +let pp_mpath fmt p = + Format.fprintf fmt "%s" (m_tostring p) + let x_tostring x = Printf.sprintf "%s./%s" (m_tostring x.x_top) x.x_sub +let pp_xpath fmt x = + Format.fprintf fmt "%s" (x_tostring x) + (* -------------------------------------------------------------------- *) type smsubst = { sms_crt : path Mp.t; diff --git a/src/ecPath.mli b/src/ecPath.mli index ef2d2e8c0f..a34361bc7b 100644 --- a/src/ecPath.mli +++ b/src/ecPath.mli @@ -13,6 +13,8 @@ and path_node = | Psymbol of symbol | Pqname of path * symbol +val pp_path : Format.formatter -> path -> unit + (* -------------------------------------------------------------------- *) val psymbol : symbol -> path val pqname : path -> symbol -> path @@ -62,6 +64,8 @@ and mpath_top = [ | `Local of ident | `Concrete of path * path option ] +val pp_mpath : Format.formatter -> mpath -> unit + (* -------------------------------------------------------------------- *) val mpath : mpath_top -> mpath list -> mpath val mpath_abs : ident -> mpath list -> mpath @@ -96,6 +100,8 @@ type xpath = private { x_tag : int; } +val pp_xpath : Format.formatter -> xpath -> unit + val xpath : mpath -> symbol -> xpath val xastrip : xpath -> xpath diff --git a/src/ecPrinting.ml b/src/ecPrinting.ml index b5f4e20300..6b50d1b871 100644 --- a/src/ecPrinting.ml +++ b/src/ecPrinting.ml @@ -194,8 +194,12 @@ module PPEnv = struct let ty_symb (ppe : t) p = let exists sm = - try EcPath.p_equal (EcEnv.Ty.lookup_path ~unique:true sm ppe.ppe_env) p - with EcEnv.LookupFailure _ -> false + let p1 = Option.map fst (EcEnv.Ty.lookup_opt sm ppe.ppe_env) in + let p2 = Option.map fst (EcEnv.TypeClass.lookup_opt sm ppe.ppe_env) in + + List.exists + (EcPath.p_equal p) + (Option.to_list p1 @ Option.to_list p2) in p_shorten ppe exists (P.toqsymbol p) @@ -206,6 +210,13 @@ module PPEnv = struct in p_shorten ppe exists (P.toqsymbol p) + let tci_symb (ppe : t) p = + let exists sm = + try EcPath.p_equal (EcEnv.TcInstance.lookup_path sm ppe.ppe_env) p + with EcEnv.LookupFailure _ -> false + in + p_shorten ppe exists (P.toqsymbol p) + let rw_symb (ppe : t) p = let exists sm = try EcPath.p_equal (EcEnv.BaseRw.lookup_path sm ppe.ppe_env) p @@ -227,7 +238,7 @@ module PPEnv = struct in p_shorten ppe exists (P.toqsymbol p) - let op_symb (ppe : t) p info = + let op_symb (ppe : t) (p : P.path) (info : ([`Expr | `Form] * etyarg list * dom) option) = let specs = [1, EcPath.pqoname (EcPath.prefix EcCoreLib.CI_Bool.p_eq) "<>"] in let check_for_local sm = @@ -241,13 +252,13 @@ module PPEnv = struct check_for_local sm; EcEnv.Op.lookup_path sm ppe.ppe_env - | Some (mode, typ, dom) -> + | Some (mode, ety, dom) -> let filter = match mode with | `Expr -> fun _ op -> not (EcDecl.is_pred op) | `Form -> fun _ _ -> true in - let tvi = Some (EcUnify.TVIunamed typ) in + let tvi = Some (EcUnify.tvi_unamed ety) in fun sm -> check_for_local sm; @@ -382,7 +393,7 @@ module PPEnv = struct exception FoundUnivarSym of symbol - let tyunivar (ppe : t) i = + let univar (ppe : t) (i : EcUid.uid) = if not (Mint.mem i (fst !(ppe.ppe_univar))) then begin let alpha = "abcdefghijklmnopqrstuvwxyz" in @@ -495,6 +506,14 @@ let pp_paren pp fmt x = let pp_maybe_paren c pp = pp_maybe c pp_paren pp +(* -------------------------------------------------------------------- *) +let pp_bracket pp fmt x = + pp_enclose ~pre:"[" ~post:"]" pp fmt x + +(* -------------------------------------------------------------------- *) +let pp_maybe_bracket c pp = + pp_maybe c pp_bracket pp + (* -------------------------------------------------------------------- *) let pp_string fmt x = Format.fprintf fmt "%s" x @@ -547,8 +566,12 @@ let pp_tyvar ppe fmt x = Format.fprintf fmt "%s" (PPEnv.tyvar ppe x) (* -------------------------------------------------------------------- *) -let pp_tyunivar ppe fmt x = - Format.fprintf fmt "%s" (PPEnv.tyunivar ppe x) +let pp_tyunivar (ppe : PPEnv.t) (fmt : Format.formatter) (a : tyuni) = + Format.fprintf fmt "%s" (PPEnv.univar ppe (a :> EcUid.uid)) + +(* -------------------------------------------------------------------- *) +let pp_tcunivar (ppe : PPEnv.t) (fmt : Format.formatter) (a : tcuni) = + Format.fprintf fmt "%s" (PPEnv.univar ppe (a :> EcUid.uid)) (* -------------------------------------------------------------------- *) let pp_tyname ppe fmt p = @@ -559,6 +582,10 @@ let pp_tcname ppe fmt p = Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.tc_symb ppe p) (* -------------------------------------------------------------------- *) +let pp_tciname ppe fmt p = + Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.tci_symb ppe p) + + (* -------------------------------------------------------------------- *) let pp_rwname ppe fmt p = Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.rw_symb ppe p) @@ -840,7 +867,7 @@ let rec pp_type_r (pp_paren (pp_list ",@ " subpp)) xs (pp_tyname ppe) name in - maybe_paren outer t_prio_name pp fmt (name, tyargs) + maybe_paren outer t_prio_name pp fmt (name, List.map fst tyargs) end | Tfun (t1, t2) -> @@ -1104,28 +1131,94 @@ let tvi_dominated (env : EcEnv.env) (op : EcPath.path) (nargs : int) : bool = List.fold_left (fun acc ty -> Sid.union acc (EcTypes.Tvar.fv ty)) Sid.empty arg_tys in - List.for_all (fun id -> Sid.mem id covered) tparams + List.for_all (fun (id, _) -> Sid.mem id covered) tparams + +(* -------------------------------------------------------------------- *) +let pp_opname fmt (nm, op) = + let op = + if EcCoreLib.is_mixfix_op op then + Printf.sprintf "\"%s\"" op + else if is_binop op then begin + if op.[0] = '*' || op.[String.length op - 1] = '*' + then Format.sprintf "( %s )" op + else Format.sprintf "(%s)" op + end else op + + in EcSymbols.pp_qsymbol fmt (nm, op) + +(* -------------------------------------------------------------------- *) +let rec pp_etyarg (ppe : PPEnv.t) (fmt : Format.formatter) ((ty, tcws) : etyarg) = + match tcws with + | [] -> pp_type ppe fmt ty + | _ -> Format.fprintf fmt "%a[%a]" (pp_type ppe) ty (pp_tcws ppe) tcws + +(* -------------------------------------------------------------------- *) +and pp_etyargs (ppe : PPEnv.t) (fmt : Format.formatter) (etys : etyarg list) = + Format.fprintf fmt "%a" (pp_list ",@ " (pp_etyarg ppe)) etys + +(* -------------------------------------------------------------------- *) +and pp_tcw (ppe : PPEnv.t) (fmt : Format.formatter) (tcw : tcwitness) = + let pp_lift fmt = function + | [] -> () + | l when List.for_all (fun i -> i = 0) l -> + Format.fprintf fmt "^%d" (List.length l) + | l -> + Format.fprintf fmt "^[%a]" + (pp_list ",@ " (fun fmt i -> Format.fprintf fmt "%d" i)) l in + match tcw with + | TCIUni (uid, lift) -> + Format.fprintf fmt "%a%a" (pp_tcunivar ppe) uid pp_lift lift + + | TCIConcrete { path; etyargs; lift } -> + (match etyargs with + | [] -> Format.fprintf fmt "%a%a" (pp_tciname ppe) path pp_lift lift + | _ -> Format.fprintf fmt "%a[%a]%a" + (pp_tciname ppe) path (pp_etyargs ppe) etyargs pp_lift lift) + + | TCIAbstract { support = `Var x; offset; lift } -> + Format.fprintf fmt "%a.`%d%a" (pp_tyvar ppe) x (offset + 1) pp_lift lift + + | TCIAbstract { support = `Abs path; offset; lift } -> + Format.fprintf fmt "%a.`%d%a" (pp_tyname ppe) path (offset + 1) pp_lift lift + +(* -------------------------------------------------------------------- *) +and pp_tcws (ppe : PPEnv.t) (fmt : Format.formatter) (tcws : tcwitness list) = + Format.fprintf fmt "%a" (pp_list ",@ " (pp_tcw ppe)) tcws + +(* -------------------------------------------------------------------- *) +let pp_opname_with_tvi + (ppe : PPEnv.t) + (fmt : Format.formatter) + ((nm, op, tvi) : symbol list * symbol * etyarg list option) += + match tvi with + | None -> + pp_opname fmt (nm, op) + + | Some tvi -> + Format.fprintf fmt "%a<:%a>" + pp_opname (nm, op) (pp_etyargs ppe) tvi (* -------------------------------------------------------------------- *) let pp_opapp - (ppe : PPEnv.t) - (t_ty : 'a -> EcTypes.ty) - ((dt_sub : 'a -> (EcPath.path * _ * 'a list) option), - (pp_sub : PPEnv.t -> (opprec * iassoc) -> 'a pp), - (is_trm : 'a -> bool), - (is_tuple : 'a -> 'a list option), - (is_proj : EcPath.path -> 'a -> (EcIdent.t * int) option)) - (lwr_left : PPEnv.t -> ('a -> EcTypes.ty) -> 'a -> opprec -> int option) - (outer : ((_ * fixity) * iassoc)) - (fmt : Format.formatter) - ((pred : [`Expr | `Form]), - (op : EcPath.path), - (tvi : EcTypes.ty list), - (es : 'a list), - (tyopt : ty option)) + (ppe : PPEnv.t) + (t_ty : 'a -> EcTypes.ty) + ((dt_sub : 'a -> (EcPath.path * _ * 'a list) option), + (pp_sub : PPEnv.t -> opprec * iassoc -> Format.formatter -> 'a -> unit), + (is_trm : 'a -> bool), + (is_tuple : 'a -> 'a list option), + (is_proj : EcPath.path -> 'a -> (EcIdent.t * int) option)) + (lwr_left : PPEnv.t -> ('a -> EcTypes.ty) -> 'a -> + EcSymbols.symbol list -> opprec -> int option) + (outer : symbol list * ((_ * fixity) * iassoc)) + (fmt : Format.formatter) + ((pred : [`Expr | `Form]), + (op : EcPath.path), + (tvi : EcTypes.etyarg list), + (es : 'a list)) = let (nm, opname) = - PPEnv.op_symb ppe op (Some (pred, tvi, (List.map t_ty es, tyopt))) in + PPEnv.op_symb ppe op (Some (pred, tvi, List.map t_ty es)) in let pp_tuple_sub ppe prec fmt e = match is_tuple e with @@ -1171,7 +1264,7 @@ let pp_opapp let rec doit fmt args = match args with | [] -> - maybe_paren outer prio (fun fmt () -> pp fmt) fmt () + maybe_paren (snd outer) prio (fun fmt () -> pp fmt) fmt () | a :: args -> Format.fprintf fmt "%a@ %a" @@ -1195,10 +1288,10 @@ let pp_opapp pp_opname_with_tvi ppe fmt (nm, opname, tvi_opt) | _ -> - let pp_first = fun _ _ fmt op -> - pp_opname_with_tvi ppe fmt (fst op, snd op, tvi_opt) in - let pp fmt () = pp_app ppe ~pp_first ~pp_sub outer fmt ((nm, opname), es) in - maybe_paren outer max_op_prec pp fmt () + let pp_first = fun ppe _ -> pp_opname_with_tvi ppe in + let pp fmt () = + pp_app ppe ~pp_first ~pp_sub (snd outer) fmt (([], opname, tvi_opt), es) + in maybe_paren (snd outer) max_op_prec pp fmt () and try_pp_as_uniop () = match es with @@ -1216,7 +1309,7 @@ let pp_opapp (if is_trm e then "" else " ") (pp_sub ppe (opprio, `NonAssoc)) e in let pp fmt = - maybe_paren outer opprio (fun fmt () -> pp fmt) fmt + maybe_paren (snd outer) opprio (fun fmt () -> pp fmt) fmt in Some pp end @@ -1257,14 +1350,14 @@ let pp_opapp (pp_sub ppe (e_bin_prio_rop4, `Left )) e1 (pp_sub ppe (e_bin_prio_rop4, `Right)) e2 in let opprio_left = - match lwr_left ppe t_ty e2 e_bin_prio_rop4 with + match lwr_left ppe t_ty e2 nm e_bin_prio_rop4 with | None -> e_bin_prio_rop4 | Some n -> if n <= fst e_bin_prio_rop4 then (n, snd e_bin_prio_rop4) else e_bin_prio_rop4 in let pp fmt = - maybe_paren_gen outer (e_bin_prio_rop4, opprio_left) + maybe_paren_gen (snd outer) (e_bin_prio_rop4, opprio_left) (fun fmt () -> pp fmt) fmt in Some pp end @@ -1279,12 +1372,12 @@ let pp_opapp opname (pp_sub ppe (opprio, `Right)) e2 in let opprio_left = - match lwr_left ppe t_ty e2 opprio with + match lwr_left ppe t_ty e2 nm opprio with | None -> opprio | Some n -> if n <= fst opprio then (n, snd opprio) else opprio in let pp fmt = - maybe_paren_gen outer (opprio, opprio_left) + maybe_paren_gen (snd outer) (opprio, opprio_left) (fun fmt () -> pp fmt) fmt in Some pp @@ -1299,8 +1392,8 @@ let pp_opapp let pp_first _ _ fmt opname = let subpp = pp_sub ppe (e_uni_prio_rint, `NonAssoc) in Format.fprintf fmt "%a%s" subpp e opname in - let pp fmt () = pp_app ppe ~pp_first ~pp_sub outer fmt (opname, es) in - Some (maybe_paren outer max_op_prec pp) + let pp fmt () = pp_app ppe ~pp_first ~pp_sub (snd outer) fmt (opname, es) in + Some (maybe_paren (snd outer) max_op_prec pp) end | _ -> @@ -1338,7 +1431,7 @@ let pp_opapp let recp = EcDecl.operator_as_rcrd op in match EcEnv.Ty.by_path_opt recp env with - | Some { tyd_type = Record (_, fields) } + | Some { tyd_type = `Record (_, fields) } when List.length fields = List.length es -> begin let wmap = @@ -1414,7 +1507,7 @@ let pp_opapp (pp_list "@ " (pp_sub ppe (max_op_prec, `NonAssoc))) args in let pp fmt = - maybe_paren outer e_app_prio (fun fmt () -> pp fmt) fmt + maybe_paren (snd outer) e_app_prio (fun fmt () -> pp fmt) fmt in Some pp | _ -> None @@ -1440,7 +1533,7 @@ let pp_chained_orderings (type v) (pp_sub : PPEnv.t -> opprec * iassoc -> v pp) (outer : opprec * iassoc) (fmt : Format.formatter) - ((f, fs) : v * (P.path * ty list * v) list) + ((f, fs) : v * (P.path * etyarg list * v) list) = match fs with | [] -> pp_sub ppe outer fmt f @@ -1451,7 +1544,7 @@ let pp_chained_orderings (type v) ignore (List.fold_left (fun fe (op, tvi, f) -> let (nm, opname) = - PPEnv.op_symb ppe op (Some (`Form, tvi, ([t_ty fe; t_ty f], None))) + PPEnv.op_symb ppe op (Some (`Form, tvi, [t_ty fe; t_ty f])) in Format.fprintf fmt " %t@ %a" (fun fmt -> @@ -1541,7 +1634,7 @@ let pp_locality fmt lc = this function. see maybe_paren_gen for how this precedence is used *) -let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form) (opprec : opprec) : int option +let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form) (_nm : symbol list) (opprec : opprec) : int option = let rec l_l f opprec = match f.f_node with @@ -1555,7 +1648,7 @@ let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form) (opprec : else l_l f2 e_bin_prio_rop4 | Fapp ({f_node = Fop (op, tys)}, [f1; f2]) -> (let (_, opname) = - PPEnv.op_symb ppe op (Some (`Form, tys, (List.map t_ty [f1; f2], None))) in + PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in match priority_of_binop opname with | None -> None | Some opprec' -> @@ -1721,9 +1814,9 @@ and try_pp_chained_orderings let as_ordering (f : form) = match match_pp_notations ~filter:(fun (p, _) -> is_ordering_op p) ppe f with - | Some ((op, (tvi, _)), ue, ev, ov, [i1; i2]) -> begin - let ti = Tvar.subst ov in - let tvi = List.map (ti -| tvar) tvi in + | Some ((op, (tvi, _)), ue, ev, (ov : EcUnify.UniEnv.opened), [i1; i2]) -> begin + let ti = Tvar.subst ov.subst in + let tvi = List.map (fun (t, _) -> (ti (tvar t), [])) tvi in let sb = EcMatching.MEV.assubst ue ev ppe.ppe_env in let i1 = Fsubst.f_subst sb i1 in let i2 = Fsubst.f_subst sb i2 in @@ -1756,11 +1849,10 @@ and try_pp_chained_orderings Option.fold ~none:(i1, acc) ~some:(collect acc (Some i1)) f1 in match collect [] None f with - | (_, ([] | [_])) -> - false - + | (_, ([] | [_])) -> false | (f, fs) -> - pp_chained_orderings ppe f_ty pp_form_r outer fmt (f, fs); + pp_chained_orderings + ppe f_ty pp_form_r outer fmt (f, fs); true | exception Bailout -> @@ -1810,15 +1902,13 @@ and match_pp_notations let a1, a2 = List.split_at na a in f_app f a1 (toarrow (List.map f_ty a2) oty), a2 else f_app f a oty, [] in - - let ev = MEV.of_idents (List.map fst nt.ont_args) `Form in - let ue = EcUnify.UniEnv.create None in - let ov = EcUnify.UniEnv.opentvi ue tv None in - let hy = EcEnv.LDecl.init ppe.PPEnv.ppe_env [] in - let bd = match (EcEnv.Memory.get_active_ss ppe.PPEnv.ppe_env) with - | None -> form_of_expr nt.ont_body - | Some m -> (ss_inv_of_expr m nt.ont_body).inv in - let bd = Fsubst.f_subst_tvar ~freshen:true ov bd in + let ev = MEV.of_idents (List.map fst nt.ont_args) `Form in + let ue = EcUnify.UniEnv.create None in + let ov = EcUnify.UniEnv.opentvi ue tv None in + let hy = EcEnv.LDecl.init ppe.PPEnv.ppe_env [] in + let mr = odfl (EcIdent.create "&hr") (EcEnv.Memory.get_active_ss ppe.PPEnv.ppe_env) in + let bd = form_of_expr ~m:mr nt.ont_body in + let bd = Fsubst.f_subst_tvar ~freshen:true ov.subst bd in try let (ue, ev) = @@ -1858,15 +1948,28 @@ and try_pp_notations | None -> false - | Some ((p, (tv, nt)), ue, ev, ov, eargs) -> - let ti = Tvar.subst ov in - let rty = ti nt.ont_resty in - let tv = List.map (ti -| tvar) tv in - let args = List.map (curry f_local -| snd_map ti) nt.ont_args in - let args = - let subst = EcMatching.MEV.assubst ue ev ppe.ppe_env in - List.map (Fsubst.f_subst subst) args in - let f = f_app (f_op p tv rty) (args @ eargs) f.f_ty in + | Some ((p, (_tv, nt)), ue, ev, (ov : EcUnify.UniEnv.opened), eargs) -> + let ti = Tvar.subst ov.subst in + (* After [f_match_core], the abbrev's tparam univars (created by + [opentvi] in [ov.subst]) have been bound by the matcher. Chase + those bindings through the unienv so the displayed [tv] / + [resty] / [args] show concrete carriers (e.g. [c]) rather than + the fresh univars [#a, #b, ...] that [ov.subst] alone would + produce. + + Use [ov.args] (the [etyarg list] from [opentvi], which carries + both the type univar AND its TC-witness univar(s)) instead of + just the bare tparams; chasing through [mev_subst] then + resolves both the type univars AND the TC-witness univars + into their committed forms, so the printed notation shows + both the carrier ([c]) and its TC witness when one exists. *) + let mev_subst = EcMatching.MEV.assubst ue ev ppe.ppe_env in + let chase ty = EcCoreSubst.ty_subst mev_subst (ti ty) in + let rty = chase nt.ont_resty in + let tv = List.map (EcCoreSubst.etyarg_subst mev_subst) ov.args in + let args = List.map (curry f_local -| snd_map chase) nt.ont_args in + let args = List.map (Fsubst.f_subst mev_subst) args in + let f = f_app (f_op_tc p tv rty) (args @ eargs) f.f_ty in pp_form_core_r ppe outer fmt f; true and pp_poe (ppe : PPEnv.t) (fmt : Format.formatter) (poe : form Mop.t) = @@ -1900,7 +2003,7 @@ and pp_form_core_r (f : form) = let pp_opapp ppe (outer : opprec * iassoc) (fmt : Format.formatter) - (op, tys, es, tyopt) = + (op, tys, es, _tyopt) = let rec dt_sub f = match destr_app f with | ({ f_node = Fop (p, tvi) }, args) -> Some (p, tvi, args) @@ -1930,7 +2033,7 @@ and pp_form_core_r in pp_opapp ppe f_ty (dt_sub, pp_form_r, is_trm, is_tuple, is_proj) - lower_left outer fmt (`Form, op, tys, es, tyopt) + lower_left ([], outer) fmt (`Form, op, tys, es) in match f.f_node with @@ -2135,8 +2238,9 @@ and pp_form_core_r (string_of_hcmp hs.bhs_cmp) (pp_form_r ppef (max_op_prec,`NonAssoc)) (bhs_bd hs).inv - | Fpr pr-> + | Fpr pr -> let me = EcEnv.Fun.prF_memenv pr.pr_event.m pr.pr_fun ppe.PPEnv.ppe_env in + let ppep = PPEnv.create_and_push_mem ppe ~active:true me in let pm = debug_mode || pr.pr_event.m.id_symb <> "&hr" in Format.fprintf fmt "Pr[@[%a@[%t@] %a@@ %a :@ %a@]]" @@ -2152,12 +2256,7 @@ and pp_form_core_r (pp_local ppe) pr.pr_mem (pp_form ppep) pr.pr_event.inv -and pp_form_r - (ppe : PPEnv.t) - (outer : opprec * iassoc) - (fmt : Format.formatter) - (f : form) -= +and pp_form_r (ppe : PPEnv.t) outer fmt f = let printers = [try_pp_notations; try_pp_form_eqveq; @@ -2351,7 +2450,7 @@ let pp_sform ppe fmt f = (* -------------------------------------------------------------------- *) let pp_typedecl (ppe : PPEnv.t) fmt (x, tyd) = let ppe = PPEnv.enter_theory ppe (Option.get (EcPath.prefix x)) in - let ppe = PPEnv.add_locals ppe tyd.tyd_params in + let ppe = PPEnv.add_locals ppe (List.map fst tyd.tyd_params) in let name = P.basename x in let pp_prelude fmt = @@ -2359,22 +2458,34 @@ let pp_typedecl (ppe : PPEnv.t) fmt (x, tyd) = | [] -> Format.fprintf fmt "type %s" name - | [tx] -> + | [(tx, _)] -> Format.fprintf fmt "type %a %s" (pp_tyvar ppe) tx name | txs -> Format.fprintf fmt "type %a %s" - (pp_paren (pp_list ",@ " (pp_tyvar ppe))) txs name + (pp_paren (pp_list ",@ " (pp_tyvar ppe))) (List.map fst txs) name and pp_body fmt = + let pp_one_tc fmt (tc : typeclass) = + match tc.tc_args with + | [] -> pp_tyname ppe fmt tc.tc_name + | [ty] -> + Format.fprintf fmt "%a %a" + (pp_type ppe) (fst ty) (pp_tyname ppe) tc.tc_name + | tys -> + Format.fprintf fmt "(%a) %a" + (pp_list ",@ " (pp_type ppe)) (List.fst tys) + (pp_tyname ppe) tc.tc_name in match tyd.tyd_type with - | Abstract -> - () + | `Abstract [] -> () + | `Abstract tcs -> + Format.fprintf fmt " <: %a" + (pp_list " &@ " pp_one_tc) tcs - | Concrete ty -> + | `Concrete ty -> Format.fprintf fmt " =@ %a" (pp_type ppe) ty - | Datatype { tydt_ctors = cs } -> + | `Datatype { tydt_ctors = cs } -> let pp_ctor fmt (c, cty) = match cty with | [] -> @@ -2385,7 +2496,7 @@ let pp_typedecl (ppe : PPEnv.t) fmt (x, tyd) = in Format.fprintf fmt " =@ [@[%a@]]" (pp_list " |@ " pp_ctor) cs - | Record (_, fields) -> + | `Record (_, fields) -> let pp_field fmt (f, fty) = Format.fprintf fmt "%s: @[%a@]" f (pp_type ppe) fty in @@ -2394,11 +2505,36 @@ let pp_typedecl (ppe : PPEnv.t) fmt (x, tyd) = in Format.fprintf fmt "@[%a%t%t.@]" pp_locality tyd.tyd_loca pp_prelude pp_body +(* -------------------------------------------------------------------- *) +let pp_typeclass (ppe : PPEnv.t) fmt tc = + match tc.tc_args with + | [] -> + pp_tyname ppe fmt tc.tc_name + + | [ty] -> + Format.fprintf fmt "%a %a" + (pp_type ppe) (fst ty) + (pp_tyname ppe) tc.tc_name + + | tys -> + Format.fprintf fmt "(%a) %a" + (pp_list ",@ " (pp_type ppe)) (List.map fst tys) + (pp_tyname ppe) tc.tc_name + +(* -------------------------------------------------------------------- *) +let pp_tyvar_ctt (ppe : PPEnv.t) fmt (tvar, ctt) = + match ctt with + | [] -> pp_tyvar ppe fmt tvar + | ctt -> + Format.fprintf fmt "%a <: %a" + (pp_tyvar ppe) tvar + (pp_list " &@ " (fun fmt tc -> pp_typeclass ppe fmt tc)) ctt + (* -------------------------------------------------------------------- *) let pp_tyvarannot (ppe : PPEnv.t) fmt (ids: ty_param list) = match ids with | [] -> () - | ids -> Format.fprintf fmt "[%a]" (pp_list ",@ " (pp_tyvar ppe)) ids + | ids -> Format.fprintf fmt "[%a]" (pp_list ",@ " (pp_tyvar_ctt ppe)) ids let pp_pvar (ppe : PPEnv.t) fmt ids = match ids with @@ -2480,8 +2616,19 @@ let pp_codepos_path ppe = (pp_list "" (pp_codepos_step ppe)) (* -------------------------------------------------------------------- *) -let pp_codepos (ppe : PPEnv.t) (fmt : Format.formatter) ((cpath, cp1) : CP.codepos) = - Format.fprintf fmt "%a%a" (pp_codepos_path ppe) cpath (pp_codepos1 ppe) cp1 +let pp_codepos (ppe : PPEnv.t) (fmt : Format.formatter) ((nm, cp1) : CP.codepos) = + let pp_nm (fmt : Format.formatter) ((cp, bs) : CP.codepos1 * CP.codepos_brsel) = + let bs = + match bs with + | `Cond true -> "." + | `Cond false -> "?" + | `Match cp -> Format.sprintf "#%s." cp + | `MatchByPos i -> Format.sprintf "#%d." i + in + Format.fprintf fmt "%a%s" (pp_codepos1 ppe) cp bs + in + + Format.fprintf fmt "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1 (* -------------------------------------------------------------------- *) let pp_codegap1 (ppe : PPEnv.t) (fmt : Format.formatter) (g : CP.codegap1) = @@ -2509,7 +2656,7 @@ let pp_codegap_range (ppe: PPEnv.t) (fmt: Format.formatter) ((cpath, cp1r) : CP. (* -------------------------------------------------------------------- *) let pp_opdecl_pr (ppe : PPEnv.t) fmt ((basename, ts, ty, op): symbol * ty_param list * ty * prbody option) = - let ppe = PPEnv.add_locals ppe ts in + let ppe = PPEnv.add_locals ppe (List.map fst ts) in let pp_body fmt = match op with @@ -2574,8 +2721,8 @@ let pp_exception_decl (ppe: PPEnv.t) fmt basename ty = pp_opname ([], basename) pp_body (* -------------------------------------------------------------------- *) -let pp_opdecl_op (ppe : PPEnv.t) fmt (basename, ts, ty, op) = - let ppe = PPEnv.add_locals ppe ts in +let pp_opdecl_op (ppe : PPEnv.t) fmt ((basename, ts, ty, op) : symbol * ty_param list * ty * _) = + let ppe = PPEnv.add_locals ppe (List.map fst ts) in let pp_body fmt = match op with @@ -2649,8 +2796,9 @@ let pp_opdecl_op (ppe : PPEnv.t) fmt (basename, ts, ty, op) = (pp_type ppe) fix.opf_resty (pp_list "@\n" pp_branch) cfix - | Some (OP_TC) -> - Format.fprintf fmt "= < type-class-operator >" + | Some (OP_TC (path, name)) -> + Format.fprintf fmt ": %a = < type-class operator `%s' of `%a'>" + (pp_type ppe) ty name (pp_tyname ppe) path | Some (OP_Exn _) -> Format.fprintf fmt "= < exception >" @@ -2667,7 +2815,7 @@ let pp_opdecl_op (ppe : PPEnv.t) fmt (basename, ts, ty, op) = let pp_opdecl_nt (ppe : PPEnv.t) fmt ((basename, ts, _ty, nt) : symbol * ty_param list * ty * notation) = - let ppe = PPEnv.add_locals ppe ts in + let ppe = PPEnv.add_locals ppe (List.map fst ts) in let pp_body fmt = let subppe, pplocs = @@ -2716,7 +2864,7 @@ let pp_opdecl in Format.fprintf fmt "@[%a%a%a@]" pp_locality op.op_loca pp_name x pp_decl op let pp_added_op (ppe : PPEnv.t) fmt op = - let ppe = PPEnv.add_locals ppe op.op_tparams in + let ppe = PPEnv.add_locals ppe (List.map fst op.op_tparams) in match op.op_tparams with | [] -> Format.fprintf fmt ": @[%a@]" (pp_type ppe) op.op_ty @@ -2738,7 +2886,7 @@ let tags_of_axkind = function | `Lemma -> [] let pp_axiom ?(long=false) (ppe : PPEnv.t) fmt (x, ax) = - let ppe = PPEnv.add_locals ppe ax.ax_tparams in + let ppe = PPEnv.add_locals ppe (List.map fst ax.ax_tparams) in let basename = P.basename x in let pp_spec fmt = @@ -3281,8 +3429,8 @@ let pp_equivS (ppe : PPEnv.t) ?prpo fmt es = let insync = EcMemory.mt_equal (snd es.es_ml) (snd es.es_mr) - && EcReduction.EqTest.for_stmt - ppe.PPEnv.ppe_env ~norm:false es.es_sl es.es_sr in +(* && EcReduction.EqTest.for_stmt + ppe.PPEnv.ppe_env ~norm:false es.es_sl es.es_sr in *) in let ppnode = if insync then begin @@ -3317,6 +3465,50 @@ let pp_rwbase ppe fmt (p, rws) = Format.fprintf fmt "%a = %a@\n%!" (pp_rwname ppe) p (pp_list ", " (pp_axname ppe)) (Sp.elements rws) +(* -------------------------------------------------------------------- *) +let pp_tparam ppe fmt (id, tcs) = + Format.fprintf fmt "%a <: %a" + pp_symbol (EcIdent.name id) + (pp_list " &@ " (pp_typeclass ppe)) tcs + +let pp_tparams ppe fmt tparams = + Format.fprintf fmt "%a" + (pp_maybe (List.length tparams != 0) (pp_enclose ~pre:"[" ~post:"] ") (pp_list ",@ " (pp_tparam ppe))) tparams + +let pp_prts ppe fmt = function + | [] -> () + | tcs -> + let pp_one fmt (p, _ren) = pp_typeclass ppe fmt p in + Format.fprintf fmt " <: %a" + (pp_list "@ & " pp_one) tcs + +let pp_op ppe fmt (t, ty) = + Format.fprintf fmt " @[op %s :@ %a.@]" + (EcIdent.name t) + (pp_type ppe) ty + +let pp_ops ppe fmt ops = + pp_maybe (List.length ops != 0) (pp_enclose ~pre:"" ~post:"@,@,") (pp_list "@,@," (pp_op ppe)) fmt ops + +let pp_ax ppe fmt (s, f) = + Format.fprintf fmt " @[axiom %s :@ %a.@]" + s (pp_form ppe) f + +let pp_axs ppe fmt axs = + pp_maybe (List.length axs != 0) (pp_enclose ~pre:"" ~post:"@,@,") (pp_list "@,@," (pp_ax ppe)) fmt axs + +let pp_ops_axs ppe fmt (ops, axs) = + Format.fprintf fmt "%a%a" + (pp_maybe (List.length ops + List.length axs != 0) (pp_enclose ~pre:"@,@," ~post:"") (pp_ops ppe)) ops + (pp_axs ppe) axs + +let pp_tc_decl ppe fmt (p, tcdecl) = + Format.fprintf fmt "@[type class %a%a%a = {%a}.@]" + (pp_tparams ppe) tcdecl.tc_tparams + (pp_tyname ppe) p + (pp_prts ppe) tcdecl.tc_prts + (pp_ops_axs ppe) (tcdecl.tc_ops, tcdecl.tc_axs) + (* -------------------------------------------------------------------- *) let pp_solvedb ppe fmt (db: (int * (P.path * _) list) list) = List.iter (fun (lvl, ps) -> @@ -3409,7 +3601,7 @@ module PPGoal = struct in (ppe, (id, pdk)) let pp_goal1 ?(pphyps = true) ?prpo ?(idx) (ppe : PPEnv.t) fmt (hyps, concl) = - let ppe = PPEnv.add_locals ppe hyps.EcBaseLogic.h_tvar in + let ppe = PPEnv.add_locals ppe (List.map fst hyps.EcBaseLogic.h_tvar) in let ppe, pps = List.map_fold pre_pp_hyp ppe (List.rev hyps.EcBaseLogic.h_local) in idx |> oiter (Format.fprintf fmt "Goal #%d@\n"); @@ -3420,7 +3612,7 @@ module PPGoal = struct | [] -> Format.fprintf fmt "Type variables: @\n\n%!" | tv -> Format.fprintf fmt "Type variables: %a@\n\n%!" - (pp_list ", " (pp_tyvar ppe)) tv + (pp_list ", " (pp_tyvar ppe)) (List.map fst tv) end; List.iter (fun (id, (pk, dk)) -> let pk fmt = @@ -3455,7 +3647,7 @@ end (* -------------------------------------------------------------------- *) let pp_hyps (ppe : PPEnv.t) fmt hyps = let hyps = EcEnv.LDecl.tohyps hyps in - let ppe = PPEnv.add_locals ppe hyps.EcBaseLogic.h_tvar in + let ppe = PPEnv.add_locals ppe (List.map fst hyps.EcBaseLogic.h_tvar) in let ppe, pps = List.map_fold PPGoal.pre_pp_hyp ppe (List.rev hyps.EcBaseLogic.h_local) in @@ -3464,7 +3656,7 @@ let pp_hyps (ppe : PPEnv.t) fmt hyps = | [] -> Format.fprintf fmt "Type variables: @\n\n%!" | tv -> Format.fprintf fmt "Type variables: %a@\n\n%!" - (pp_list ", " (pp_tyvar ppe)) tv + (pp_list ", " (pp_tyvar ppe)) (List.map fst tv) end; List.iter (fun (id, (pk, dk)) -> let pk fmt = @@ -3615,7 +3807,7 @@ let rec pp_instr_r (ppe : PPEnv.t) fmt i = let pp_branch fmt ((vars, s), (cname, _)) = let ptn = EcTypes.toarrow (List.snd vars) e.e_ty in - let ptn = f_op (EcPath.pqoname (EcPath.prefix p) cname) typ ptn in + let ptn = f_op_tc (EcPath.pqoname (EcPath.prefix p) cname) typ ptn in let ptn = f_app ptn (List.map (fun (x, ty) -> f_local x ty) vars) e.e_ty in Format.fprintf fmt "| %a => @[%a@]@ " @@ -3770,10 +3962,13 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = Format.fprintf fmt "export %a." EcSymbols.pp_qsymbol (PPEnv.th_symb ppe p) - | EcTheory.Th_instance ((typ, ty), tc, lc) -> begin - let ppe = PPEnv.add_locals ppe typ in (* FIXME *) + | EcTheory.Th_typeclass _ -> + Format.fprintf fmt "typeclass ." + + | EcTheory.Th_instance (_, tci) -> begin + let ppe = PPEnv.add_locals ppe (List.map fst tci.tci_params) in - match tc with + match tci.tci_instance with | (`Ring _ | `Field _) as tc -> begin let (name, ops) = let rec ops_of_ring cr = @@ -3809,10 +4004,10 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = in Format.fprintf fmt "%ainstance %s with [%a] %a@\n@[ %a@]" - pp_locality lc + pp_locality tci.tci_local name - (pp_paren (pp_list ",@ " (pp_tyvar ppe))) typ - (pp_type ppe) ty + (pp_paren (pp_list ",@ " (pp_tyvar ppe))) (List.map fst tci.tci_params) + (pp_type ppe) tci.tci_type (pp_list "@\n" (fun fmt (name, op) -> Format.fprintf fmt "op %s = %s" @@ -3820,9 +4015,11 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = ops end - | `General p -> + | `General (tc, _) -> Format.fprintf fmt "%ainstance %a with %a." - pp_locality lc (pp_type ppe) ty pp_path p + pp_locality tci.tci_local + (pp_type ppe) tci.tci_type + (pp_typeclass ppe) tc end | EcTheory.Th_baserw (name, _lc) -> @@ -4043,6 +4240,12 @@ module ObjectInfo = struct | `Rewrite name -> pr_rw fmt env name | `Solve name -> pr_at fmt env name + (* ------------------------------------------------------------------ *) + let pr_tc_r = + { od_name = "type classes"; + od_lookup = EcEnv.TypeClass.lookup; + od_printer = pp_tc_decl; } + (* ------------------------------------------------------------------ *) let pr_any fmt env qs = let printers = [pr_gen_r ~prcat:true pr_ty_r ; @@ -4052,7 +4255,8 @@ module ObjectInfo = struct pr_gen_r ~prcat:true pr_mod_r; pr_gen_r ~prcat:true pr_mty_r; pr_gen_r ~prcat:true pr_rw_r ; - pr_gen_r ~prcat:true pr_at_r ; ] in + pr_gen_r ~prcat:true pr_at_r ; + pr_gen_r ~prcat:true pr_tc_r ; ] in let ok = ref (List.length printers) in diff --git a/src/ecProcSem.ml b/src/ecProcSem.ml index 2fb1af20f5..c997d47005 100644 --- a/src/ecProcSem.ml +++ b/src/ecProcSem.ml @@ -416,7 +416,7 @@ and translate_e (env : senv) (e : expr) = raise SemNotSupported | _ -> - e_map (fun x -> x) (translate_e env) e + e_map (fun ty -> ty) (translate_e env) e (* -------------------------------------------------------------------- *) and translate_lv (env : senv) (lv : lvalue) : lpattern = diff --git a/src/ecProofTerm.ml b/src/ecProofTerm.ml index 4235c079a8..2bb225d7f4 100644 --- a/src/ecProofTerm.ml +++ b/src/ecProofTerm.ml @@ -120,8 +120,8 @@ let concretize_e_form_gen (CPTEnv subst) ids f = f_forall ids f (* -------------------------------------------------------------------- *) -let concretize_e_form cptenv f = - concretize_e_form_gen cptenv [] f +let concretize_e_form (CPTEnv subst) f = + Fsubst.f_subst subst f (* -------------------------------------------------------------------- *) let rec concretize_e_arg ((CPTEnv subst) as cptenv) arg = @@ -137,7 +137,7 @@ and concretize_e_head ((CPTEnv subst) as cptenv) head = | PTCut (f, s) -> PTCut (Fsubst.f_subst subst f, s) | PTHandle h -> PTHandle h | PTLocal x -> PTLocal x - | PTGlobal (p, tys) -> PTGlobal (p, List.map (ty_subst subst) tys) + | PTGlobal (p, tys) -> PTGlobal (p, List.map (EcCoreSubst.etyarg_subst subst) tys) | PTTerm pt -> PTTerm (concretize_e_pt cptenv pt) and concretize_e_pt ((CPTEnv subst) as cptenv) pt = @@ -191,23 +191,31 @@ let pt_of_hyp_r ptenv x = ptev_ax = ax; } (* -------------------------------------------------------------------- *) -let pt_of_global pf hyps p tys = +let pt_of_global_tc pf hyps p etyargs = let ptenv = ptenv_of_penv hyps pf in - let ax = EcEnv.Ax.instantiate p tys (LDecl.toenv hyps) in + let ax = EcEnv.Ax.instanciate p etyargs (LDecl.toenv hyps) in { ptev_env = ptenv; - ptev_pt = ptglobal ~tys p; + ptev_pt = ptglobal ~tys:etyargs p; ptev_ax = ax; } (* -------------------------------------------------------------------- *) -let pt_of_global_r ptenv p tys = +let pt_of_global pf hyps p tys = + pt_of_global_tc pf hyps p (List.map (fun ty -> (ty, [])) tys) + +(* -------------------------------------------------------------------- *) +let pt_of_global_tc_r ptenv p etyargs = let env = LDecl.toenv ptenv.pte_hy in - let ax = EcEnv.Ax.instantiate p tys env in + let ax = EcEnv.Ax.instanciate p etyargs env in { ptev_env = ptenv; - ptev_pt = ptglobal ~tys p; + ptev_pt = ptglobal ~tys:etyargs p; ptev_ax = ax; } +(* -------------------------------------------------------------------- *) +let pt_of_global_r ptenv p tys = + pt_of_global_tc_r ptenv p (List.map (fun ty -> (ty, [])) tys) + (* -------------------------------------------------------------------- *) let pt_of_handle_r ptenv hd = let g = FApi.get_pregoal_by_id hd ptenv.pte_pe in @@ -222,13 +230,11 @@ let pt_of_uglobal_r ptenv p = let ax = oget (EcEnv.Ax.by_path_opt p env) in let typ, ax = (ax.EcDecl.ax_tparams, ax.EcDecl.ax_spec) in - (* FIXME: TC HOOK *) let fs = EcUnify.UniEnv.opentvi ptenv.pte_ue typ None in - let ax = Fsubst.f_subst_tvar ~freshen:true fs ax in - let typ = List.map (fun a -> EcIdent.Mid.find a fs) typ in + let ax = Fsubst.f_subst_tvar ~freshen:true fs.subst ax in { ptev_env = ptenv; - ptev_pt = ptglobal ~tys:typ p; + ptev_pt = ptglobal ~tys:fs.args p; ptev_ax = ax; } (* -------------------------------------------------------------------- *) @@ -313,21 +319,73 @@ let pf_find_occurence | _, _ -> false in + (* Two heads match keywise iff they're path-equal, OR the candidate's + head TC-reduces (via factory rename on its abstract witness) to an + [Fop] with the pattern's key. Without the second clause, [rewrite L] + misses positions where [L]'s LHS uses a class op like [( * )<:comring>] + and the goal has the rename-equivalent [(+)<:t mulmonoid leg>] — + deeper matching would resolve them, but [keycheck] would have + filtered them out first. *) + let env_for_kmatch = EcEnv.LDecl.toenv pt.pte_hy in + let head_op_after_tc_reduce (head : form) : EcPath.path option = + match head.f_node with + | Fop (p, tys) -> begin + match EcEnv.Op.tc_reduce env_for_kmatch p tys with + | exception EcEnv.NotReducible -> None + | reduced -> begin + match (fst (destr_app reduced)).f_node with + | Fop (p', _) -> Some p' + | _ -> None + end + end + | _ -> None in + (* Reverse-instance lookup: given a TC class op [tcop] and a + concrete op [concrete], is there a registered instance where + [tcop]'s realisation is [concrete]? Used by [kmatch] when the + pattern is a TC-op call with a univar carrier (so [tc_reduce] + can't fire forward) and the goal's head is the concrete + realisation. Lets [rewrite mul0r] (no TVI) match positions + whose head is e.g. [polyM] — pinning the carrier via + [try_delta] / [doit_tc_reduce] downstream. *) + let tc_op_realised_by tcop concrete = + EcEnv.Op.tc_op_realised_by env_for_kmatch tcop concrete in + (* Compute the alternative head an [Fop p tys] could expose after a + single [tc_reduce] step at the carrier. Used for both pattern- and + goal-side keyed matching. *) + let kmatch_alt_head (head : form) : EcPath.path option = + head_op_after_tc_reduce head in let kmatch key tp = - match key, (fst (destr_app tp)).f_node with + let tp_head = fst (destr_app tp) in + match key, tp_head.f_node with | `NoKey , _ -> true - | `Path p, Fop (p', _) -> EcPath.p_equal p p' - | `Path _, _ -> false + | `Path p, Fop (p', _) when EcPath.p_equal p p' -> true + | `Path p, Fop (p', _) when tc_op_realised_by p p' -> true + | `Path p, _ -> begin + match kmatch_alt_head tp_head with + | Some p' -> EcPath.p_equal p p' + | None -> false + end | `Var x, Flocal x' -> id_equal x x' | `Var _, _ -> false in let keycheck tp key = not occmode.k_keyed || kmatch key tp in - (* Extract key from pattern *) + (* Extract key from pattern. For a TC-op pattern, take the *reduced* + head as the key when [tc_reduce] yields a concrete op at the + pattern's carrier — that's the form most goals will have after + abbrev expansion at that carrier. Without this, [rewrite L] with + [L] using a class op like [(+)<:int poly>] would key on [(+)] + and miss goals where the same position has been elaborated to + the carrier's structural realisation (e.g. [polyD]). *) let key = - match (fst (destr_app ptn)).f_node with - | Fop (p, _) -> `Path p + let ptn_head = fst (destr_app ptn) in + match ptn_head.f_node with + | Fop (p, _) -> begin + match kmatch_alt_head ptn_head with + | Some p' -> `Path p' + | None -> `Path p + end | Flocal x -> if is_none (EcMatching.MEV.get x `Form !(pt.pte_ev)) then `Var x @@ -514,14 +572,12 @@ let process_named_pterm pe (tvi, fp) = (fun () -> omap (EcTyping.transtvi env pe.pte_ue) tvi) in - PT.pf_check_tvi pe.pte_pe typ tvi; + PT.pf_check_tvi env pe.pte_pe typ tvi; - (* FIXME: TC HOOK *) let fs = EcUnify.UniEnv.opentvi pe.pte_ue typ tvi in - let ax = Fsubst.f_subst_tvar ~freshen:false fs ax in - let typ = List.map (fun a -> EcIdent.Mid.find a fs) typ in + let ax = Fsubst.f_subst_tvar ~freshen:false fs.subst ax in - (p, (typ, ax)) + (p, (fs.args, ax)) (* ------------------------------------------------------------------ *) let process_pterm_cut ~prcut pe pt = @@ -918,7 +974,7 @@ let tc1_process_full_closed_pterm (tc : tcenv1) (ff : ppterm) = (* -------------------------------------------------------------------- *) type prept = [ | `Hy of EcIdent.t - | `G of EcPath.path * ty list + | `G of EcPath.path * etyarg list | `UG of EcPath.path | `HD of handle | `PE of pt_ev @@ -937,8 +993,8 @@ and prept_arg = [ let pt_of_prept_r (ptenv : pt_env) : prept -> pt_ev = let rec build_pt : prept -> pt_ev = function | `Hy id -> pt_of_hyp_r ptenv id - | `G (p, tys) -> pt_of_global_r ptenv p tys - | `UG p -> pt_of_global_r ptenv p [] + | `G (p, tys) -> pt_of_global_tc_r ptenv p tys + | `UG p -> pt_of_global_tc_r ptenv p [] | `HD hd -> pt_of_handle_r ptenv hd | `PE pe -> pe | `App (pt, args) -> List.fold_left app_pt_ev (build_pt pt) args diff --git a/src/ecProofTerm.mli b/src/ecProofTerm.mli index af3d0509fe..2045187f6d 100644 --- a/src/ecProofTerm.mli +++ b/src/ecProofTerm.mli @@ -154,12 +154,13 @@ val ptenv : proofenv -> LDecl.hyps -> (EcUnify.unienv * mevmap) -> pt_env val copy : pt_env -> pt_env (* Proof-terms construction from components *) -val pt_of_hyp : proofenv -> LDecl.hyps -> EcIdent.t -> pt_ev -val pt_of_global_r : pt_env -> EcPath.path -> ty list -> pt_ev -val pt_of_global : proofenv -> LDecl.hyps -> EcPath.path -> ty list -> pt_ev -val pt_of_uglobal_r : pt_env -> EcPath.path -> pt_ev -val pt_of_uglobal : proofenv -> LDecl.hyps -> EcPath.path -> pt_ev - +val pt_of_hyp : proofenv -> LDecl.hyps -> EcIdent.t -> pt_ev +val pt_of_global_tc_r : pt_env -> EcPath.path -> etyarg list -> pt_ev +val pt_of_global_tc : proofenv -> LDecl.hyps -> EcPath.path -> etyarg list -> pt_ev +val pt_of_global_r : pt_env -> EcPath.path -> ty list -> pt_ev +val pt_of_global : proofenv -> LDecl.hyps -> EcPath.path -> ty list -> pt_ev +val pt_of_uglobal_r : pt_env -> EcPath.path -> pt_ev +val pt_of_uglobal : proofenv -> LDecl.hyps -> EcPath.path -> pt_ev (* -------------------------------------------------------------------- *) val ffpattern_of_genpattern : LDecl.hyps -> genpattern -> ppterm option @@ -167,7 +168,7 @@ val ffpattern_of_genpattern : LDecl.hyps -> genpattern -> ppterm option (* -------------------------------------------------------------------- *) type prept = [ | `Hy of EcIdent.t - | `G of EcPath.path * ty list + | `G of EcPath.path * etyarg list | `UG of EcPath.path | `HD of handle | `PE of pt_ev @@ -190,7 +191,7 @@ module Prept : sig val (@) : prept -> prept_arg list -> prept val hyp : EcIdent.t -> prept - val glob : EcPath.path -> ty list -> prept + val glob : EcPath.path -> etyarg list -> prept val uglob : EcPath.path -> prept val hdl : handle -> prept diff --git a/src/ecProofTyping.ml b/src/ecProofTyping.ml index 6dffd0f6d9..639437eb0f 100644 --- a/src/ecProofTyping.ml +++ b/src/ecProofTyping.ml @@ -1,14 +1,13 @@ (* -------------------------------------------------------------------- *) open EcUtils open EcIdent +open EcAst open EcTypes open EcPath open EcFol open EcEnv open EcCoreGoal -open EcAst open EcParsetree -open EcUnify module Msym = EcSymbols.Msym @@ -26,12 +25,12 @@ let process_form_opt ?mv hyps pf oty = try let ue = unienv_of_hyps hyps in let ff = EcTyping.trans_form_opt ?mv (LDecl.toenv hyps) ue pf oty in - let ts = Tuni.subst (EcUnify.UniEnv.close ue) in + let ts = Tuni.subst ~tw_uni:(EcUnify.UniEnv.tw_assubst ue) (EcUnify.UniEnv.close ue) in EcFol.Fsubst.f_subst ts ff - with EcUnify.UninstantiateUni -> + with EcUnify.UninstanciateUni infos -> EcTyping.tyerror pf.EcLocation.pl_loc - (LDecl.toenv hyps) EcTyping.FreeTypeVariables + (LDecl.toenv hyps) (FreeUniVariables infos) (* ------------------------------------------------------------------ *) let process_form ?mv hyps pf ty = @@ -60,8 +59,10 @@ let process_type hyps pty = let ue = unienv_of_hyps hyps in let ty = EcTyping.transty EcTyping.tp_tydecl env ue pty in - if not (EcUnify.UniEnv.closed ue) then - EcTyping.tyerror (EcLocation.loc pty) env EcTyping.FreeTypeVariables; + begin match EcUnify.UniEnv.xclosed ue with + | None -> () + | Some flags -> EcTyping.tyerror (EcLocation.loc pty) env (EcTyping.FreeUniVariables flags) + end; let ts = Tuni.subst (EcUnify.UniEnv.close ue) in EcCoreSubst.ty_subst ts ty @@ -73,17 +74,17 @@ let process_stmt hyps s = let s = EcTyping.transstmt env ue s in try - let ts = Tuni.subst (EcUnify.UniEnv.close ue) in + let ts = Tuni.subst ~tw_uni:(EcUnify.UniEnv.tw_assubst ue) (EcUnify.UniEnv.close ue) in s_subst ts s - with EcUnify.UninstantiateUni -> - EcTyping.tyerror EcLocation._dummy env EcTyping.FreeTypeVariables + with EcUnify.UninstanciateUni flags -> + EcTyping.tyerror EcLocation._dummy env (EcTyping.FreeUniVariables flags) (* ------------------------------------------------------------------ *) let process_exp hyps mode oty e = let env = LDecl.toenv hyps in let ue = unienv_of_hyps hyps in let e = EcTyping.transexpcast_opt env mode ue oty e in - let ts = Tuni.subst (EcUnify.UniEnv.close ue) in + let ts = Tuni.subst ~tw_uni:(EcUnify.UniEnv.tw_assubst ue) (EcUnify.UniEnv.close ue) in e_subst ts e (* ------------------------------------------------------------------ *) @@ -166,7 +167,8 @@ let tc1_process_stmt ?map hyps tc c = let ue = unienv_of_hyps hyps in let c = Exn.recast_pe !!tc hyps (fun () -> EcTyping.transstmt ?map env ue c) in let uidmap = Exn.recast_pe !!tc hyps (fun () -> EcUnify.UniEnv.close ue) in - let es = Tuni.subst uidmap in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let es = Tuni.subst ~tw_uni uidmap in s_subst es c @@ -228,11 +230,53 @@ let tc1_process_Xhl_formula ?side tc pf = let tc1_process_Xhl_formula_xreal tc pf = tc1_process_Xhl_form tc txreal pf + (* ------------------------------------------------------------------ *) -(* FIXME: factor out to typing module *) -(* FIXME: TC HOOK - check parameter constraints *) -(* ------------------------------------------------------------------ *) -let pf_check_tvi (pe : proofenv) (typ : EcDecl.ty_params) (tvi : tvar_inst option) = +let pf_check_tvi (env : env) (pe : proofenv) typ tvi = + let rec is_ground (ty : ty) = + match ty.ty_node with + | Tunivar _ | Tvar _ -> false + | _ -> not (ty_sub_exists (fun t -> not (is_ground t)) ty) in + + (* Walk the ancestor chain of each TC declared on an abstract type + [p] (i.e. [tyd_type = `Abstract tcs]) and accept [tc] if it + appears anywhere in [ancestors tcs(i)]. This mirrors Mode #6 of + the unifier strategies (see [strat_abs_via_decl] in ecUnify.ml). *) + let abs_satisfies (ty : ty) (tc : typeclass) = + match ty.ty_node with + | Tconstr (p, _) -> begin + match EcEnv.Ty.by_path_opt p env with + | Some { tyd_type = `Abstract tcs; _ } -> + let eq_tc tc' = + EcPath.p_equal tc.tc_name tc'.tc_name + && List.length tc.tc_args = List.length tc'.tc_args + && List.for_all2 + (fun (a, _) (b, _) -> EcCoreEqTest.for_type env a b) + tc.tc_args tc'.tc_args in + List.exists + (fun tc' -> List.exists eq_tc (EcTypeClass.ancestors env tc')) + tcs + | _ -> false + end + | _ -> false in + + (* Constraints can reference earlier tparams (e.g. 'c <: ('a, 'b) embed + references 'a, 'b). We substitute the user-supplied tparam values + before calling [infer]. *) + let check_constraints (subst : etyarg Mid.t) (tcs : typeclass list) (ty : ty) = + if is_ground ty then + List.iter (fun tc -> + let tc = EcCoreSubst.Tvar.subst_tc subst tc in + if Option.is_none (EcTypeClass.infer env ty tc) + && not (abs_satisfies ty tc) then + let ppe = EcPrinting.PPEnv.ofenv env in + tc_error_lazy pe (fun fmt -> + Format.fprintf fmt + "type @[%a@] does not satisfy typeclass constraint @[%a@]" + (EcPrinting.pp_type ppe) ty + (EcPrinting.pp_tyname ppe) tc.tc_name) + ) tcs in + match tvi with | None -> () @@ -240,15 +284,32 @@ let pf_check_tvi (pe : proofenv) (typ : EcDecl.ty_params) (tvi : tvar_inst optio if List.length tyargs <> List.length typ then tc_error pe "wrong number of type parameters (%d, expecting %d)" - (List.length tyargs) (List.length typ) + (List.length tyargs) (List.length typ); + let _ : etyarg Mid.t = + List.fold_left2 (fun subst (id, tcs) (ty_opt, _) -> + Option.iter (check_constraints subst tcs) ty_opt; + match ty_opt with + | Some ty -> Mid.add id (ty, []) subst + | None -> subst + ) Mid.empty typ tyargs + in () | Some (EcUnify.TVInamed tyargs) -> - let typnames = List.map EcIdent.name typ in + let typnames = List.map (fun (id, _) -> EcIdent.name id) typ in List.iter (fun (x, _) -> if not (List.mem x typnames) then tc_error pe "unknown type variable: %s" x) - tyargs + tyargs; + let _ : etyarg Mid.t = + List.fold_left (fun subst (id, tcs) -> + match List.assoc_opt (EcIdent.name id) tyargs with + | Some (Some ty, _) -> + check_constraints subst tcs ty; + Mid.add id (ty, []) subst + | _ -> subst + ) Mid.empty typ + in () (* -------------------------------------------------------------------- *) exception NoMatch diff --git a/src/ecProofTyping.mli b/src/ecProofTyping.mli index b5622e8b09..3a55995e46 100644 --- a/src/ecProofTyping.mli +++ b/src/ecProofTyping.mli @@ -1,11 +1,13 @@ (* -------------------------------------------------------------------- *) open EcParsetree open EcIdent +open EcAst +open EcFol open EcPath open EcDecl open EcEnv open EcCoreGoal -open EcAst +open EcMemory (* -------------------------------------------------------------------- *) type ptnenv = ty Mid.t * EcUnify.unienv @@ -15,7 +17,7 @@ type metavs = EcFol.form EcSymbols.Msym.t * proof-environment. See the [Exn] module for more information. *) val unienv_of_hyps : LDecl.hyps -> EcUnify.unienv -val pf_check_tvi : proofenv -> ty_params -> EcUnify.tvi -> unit +val pf_check_tvi : env -> proofenv -> ty_params -> EcUnify.tvi -> unit (* Typing in the environment implied by [LDecl.hyps]. *) val process_form_opt : ?mv:metavs -> LDecl.hyps -> pformula -> ty option -> form diff --git a/src/ecReduction.ml b/src/ecReduction.ml index d28cb2058c..21b14a7530 100644 --- a/src/ecReduction.ml +++ b/src/ecReduction.ml @@ -16,47 +16,15 @@ exception IncompatibleType of env * (ty * ty) exception IncompatibleForm of env * (form * form) exception IncompatibleExpr of env * (expr * expr) -(* -------------------------------------------------------------------- *) -type 'a eqtest = env -> 'a -> 'a -> bool +type 'a eqtest = env -> 'a -> 'a -> bool type 'a eqntest = env -> ?norm:bool -> 'a -> 'a -> bool type 'a eqantest = env -> ?alpha:(EcIdent.t * ty) Mid.t -> ?norm:bool -> 'a -> 'a -> bool +(* -------------------------------------------------------------------- *) module EqTest_base = struct - let rec for_type env t1 t2 = - ty_equal t1 t2 || for_type_r env t1 t2 - - and for_type_r env t1 t2 = - match t1.ty_node, t2.ty_node with - | Tunivar uid1, Tunivar uid2 -> EcUid.uid_equal uid1 uid2 - - | Tvar i1, Tvar i2 -> i1 = i2 - - | Ttuple lt1, Ttuple lt2 -> - List.length lt1 = List.length lt2 - && List.all2 (for_type env) lt1 lt2 - - | Tfun (t1, t2), Tfun (t1', t2') -> - for_type env t1 t1' && for_type env t2 t2' - - | Tglob m1, Tglob m2 -> EcIdent.id_equal m1 m2 - - | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> - if - List.length lt1 = List.length lt2 - && List.all2 (for_type env) lt1 lt2 - then true - else - if Ty.defined p1 env - then for_type env (Ty.unfold p1 lt1 env) (Ty.unfold p2 lt2 env) - else false - - | Tconstr(p1,lt1), _ when Ty.defined p1 env -> - for_type env (Ty.unfold p1 lt1 env) t2 - - | _, Tconstr(p2,lt2) when Ty.defined p2 env -> - for_type env t1 (Ty.unfold p2 lt2 env) - - | _, _ -> false + (* ------------------------------------------------------------------ *) + let for_type = EcCoreEqTest.for_type + let for_etyarg = EcCoreEqTest.for_etyarg (* ------------------------------------------------------------------ *) let is_unit env ty = for_type env tunit ty @@ -137,7 +105,7 @@ module EqTest_base = struct for_pv env ~norm p1 p2 | Eop(o1,ty1), Eop(o2,ty2) -> - p_equal o1 o2 && List.all2 (for_type env) ty1 ty2 + p_equal o1 o2 && List.all2 (for_etyarg env) ty1 ty2 | Equant(q1,b1,e1), Equant(q2,b2,e2) when eqt_equal q1 q2 -> let alpha = check_bindings env alpha b1 b2 in @@ -409,6 +377,9 @@ let ensure b = if b then () else raise NotConv let check_ty env subst ty1 ty2 = ensure (EqTest_base.for_type env ty1 (ty_subst subst ty2)) +let check_etyarg env subst (ty1, w1) (ty2, w2) = + ensure (EqTest_base.for_etyarg env (ty1, w1) (ty_subst subst ty2, w2)) + let add_local (env, subst) (x1, ty1) (x2, ty2) = check_ty env subst ty1 ty2; env, @@ -538,7 +509,7 @@ let is_alpha_eq ?(subst=Fsubst.f_subst_id) hyps f1 f2 = check_mod subst m1 m2 | Fop(p1, ty1), Fop(p2, ty2) when EcPath.p_equal p1 p2 -> - List.iter2 (check_ty env subst) ty1 ty2 + List.iter2 (check_etyarg env subst) ty1 ty2 | Fapp(f1',args1), Fapp(f2',args2) when List.length args1 = List.length args2 -> @@ -636,6 +607,7 @@ type reduction_info = { beta : bool; delta_p : (path -> deltap); (* reduce operators *) delta_h : (ident -> bool); (* reduce local definitions *) + delta_tc : bool; zeta : bool; iota : bool; eta : bool; @@ -652,6 +624,7 @@ let full_red = { beta = true; delta_p = (fun _ -> `IfTransparent); delta_h = EcUtils.predT; + delta_tc = true; zeta = true; iota = true; eta = true; @@ -661,15 +634,16 @@ let full_red = { } let no_red = { - beta = false; - delta_p = (fun _ -> `No); - delta_h = EcUtils.pred0; - zeta = false; - iota = false; - eta = false; - logic = None; - modpath = false; - user = false; + beta = false; + delta_p = (fun _ -> `No); + delta_h = EcUtils.pred0; + delta_tc = false; + zeta = false; + iota = false; + eta = false; + logic = None; + modpath = false; + user = false; } let beta_red = { no_red with beta = true; } @@ -677,8 +651,8 @@ let betaiota_red = { no_red with beta = true; iota = true; } let nodelta = { full_red with - delta_h = EcUtils.pred0; - delta_p = (fun _ -> `No); } + delta_h = EcUtils.pred0; + delta_p = (fun _ -> `No); } let delta = { no_red with delta_p = (fun _ -> `IfTransparent); } @@ -708,6 +682,36 @@ let reduce_op ri env nargs p tys = Op.reduce ~mode ~nargs env p tys with NotReducible -> raise nohead +(* When a TC witness is [`Abs path] and [path] resolves to a concrete + (non-abstract) type, infer the concrete instance so that the TC op + becomes reducible. This arises after cloning an abstract theory with + a [type t <: tc] carrier substituted to a concrete type. *) +let resolve_concrete_tcw (env : EcEnv.env) (p : path) (tys : etyarg list) : etyarg list = + let op = EcEnv.Op.by_path p env in + if not (EcDecl.is_tc_op op) then tys + else match List.rev tys with + | (carrier_ty, [TCIAbstract { support = `Abs ap; offset = 0; lift = [] }]) :: rest + when (match EcEnv.Ty.by_path_opt ap env with + | Some { tyd_type = `Abstract _; _ } -> false + | _ -> true) -> + let tcpath, _ = EcDecl.operator_as_tc op in + let tc_decl = EcEnv.TypeClass.by_path tcpath env in + let tc = { tc_name = tcpath; + tc_args = EcDecl.etyargs_of_tparams tc_decl.tc_tparams; } in + (match EcTypeClass.infer env carrier_ty tc with + | Some w -> List.rev ((carrier_ty, [w]) :: rest) + | None -> tys) + | _ -> tys + +let reduce_tc_op (ri : reduction_info) (env : EcEnv.env) (p : path) (tys : etyarg list) = + if ri.delta_tc then + try + Op.tc_reduce env p (resolve_concrete_tcw env p tys) + with NotReducible -> raise nohead + else + raise nohead + +(* -------------------------------------------------------------------- *) let is_record env f = match EcFol.destr_app f with | { f_node = Fop (p, _) }, _ -> EcEnv.Op.is_record_ctor env p @@ -750,8 +754,8 @@ let reduce_user_gen simplify ri env hyps f = oget ~exn:needsubterm (List.Exceptionless.find_map (fun rule -> try - let ue = EcUnify.UniEnv.create None in - let tvi = EcUnify.UniEnv.opentvi ue rule.R.rl_tyd None in + let ue = EcUnify.UniEnv.create None in + let tvi = EcUnify.UniEnv.opentvi ue rule.R.rl_tyd None in let check_alpha_eq f f' = if not (is_alpha_eq hyps f f') then raise NotReducible @@ -769,10 +773,12 @@ let reduce_user_gen simplify ri env hyps f = | ({ f_node = Fop (p, tys) }, args), R.Rule (`Op (p', tys'), args') when EcPath.p_equal p p' && List.length args = List.length args' -> - let tys' = List.map (Tvar.subst tvi) tys' in + let tys' = List.map (Tvar.subst_etyarg tvi.subst) tys' in begin - try List.iter2 (EcUnify.unify env ue) tys tys' + try + if List.length tys <> List.length tys' then raise NotReducible; + List.iter2 (EcUnify.unify_etyarg env ue) tys tys' with EcUnify.UnificationFailure _ -> raise NotReducible end; List.iter2 doit args args' @@ -804,7 +810,7 @@ let reduce_user_gen simplify ri env hyps f = let subst = ts in let subst = Mid.fold (fun x f s -> Fsubst.f_bind_local s x f) !pv subst in - Fsubst.f_subst subst (Fsubst.f_subst_tvar ~freshen:true tvi f) + Fsubst.f_subst subst (Fsubst.f_subst_tvar ~freshen:true tvi.subst f) in List.iter (fun cond -> @@ -883,7 +889,7 @@ let reduce_logic ri env hyps f p args = when EcPath.p_equal p1 p2 && EcEnv.Op.is_record_ctor env p1 && EcEnv.Op.is_record_ctor env p2 - && List.for_all2 (EqTest_i.for_type env) tys1 tys2 -> + && List.for_all2 (EqTest_i.for_etyarg env) tys1 tys2 -> f_ands (List.map2 f_eq args1 args2) @@ -904,14 +910,31 @@ let reduce_logic ri env hyps f p args = check_reduced hyps needsubterm f f' (* -------------------------------------------------------------------- *) -let reduce_delta ri env _hyps f = +let reduce_delta ri env f = match f.f_node with | Fop (p, tys) when ri.delta_p p <> `No -> - reduce_op ri env 0 p tys + reduce_op ri env 0 p tys | Fapp ({ f_node = Fop (p, tys) }, args) when ri.delta_p p <> `No -> - let op = reduce_op ri env (List.length args) p tys in - f_app_simpl op args f.f_ty + let op = reduce_op ri env (List.length args) p tys in + f_app_simpl op args f.f_ty + + | _ -> raise nohead + +(* -------------------------------------------------------------------- *) +let reduce_tc ri env f = + match f.f_node with + | Fop (p, etyargs) + when ri.delta_tc && + Op.tc_reducible env p (resolve_concrete_tcw env p etyargs) -> + reduce_tc_op ri env p etyargs + + | Fapp ({ f_node = Fop (p, etyargs) }, args) + when ri.delta_tc && + Op.tc_reducible env p (resolve_concrete_tcw env p etyargs) + -> + let op = reduce_tc_op ri env p etyargs in + f_app_simpl op args f.f_ty | _ -> raise nohead @@ -1064,7 +1087,10 @@ let reduce_head simplify ri env hyps f = let body = EcFol.form_of_expr body in (* FIXME subst-refact can we do both subst in once *) let body = - Tvar.f_subst ~freshen:true op.EcDecl.op_tparams tys body in + Tvar.f_subst ~freshen:true + (List.combine + (List.map fst op.EcDecl.op_tparams) + tys) body in f_app (Fsubst.f_subst subst body) eargs f.f_ty @@ -1081,19 +1107,22 @@ let reduce_head simplify ri env hyps f = when ri.eta && can_eta x (fn, args) -> f_app fn (List.take (List.length args - 1) args) f.f_ty - | Fop _ -> begin - try - reduce_user_gen simplify ri env hyps f - with NotRed _ -> - reduce_delta ri env hyps f - end + | Fop _ -> + oget ~exn:nohead @@ + List.find_map_opt + (fun cb -> try Some (cb f) with NotRed _ -> None) + [ reduce_user_gen simplify ri env hyps + ; reduce_delta ri env + ; reduce_tc ri env ] - | Fapp({ f_node = Fop(p,_); }, args) -> begin + | Fapp ({ f_node = Fop (p, _); }, args) -> begin try reduce_logic ri env hyps f p args with NotRed kind1 -> try reduce_user_gen simplify ri env hyps f with NotRed kind2 -> - if kind1 = NoHead && kind2 = NoHead then reduce_delta ri env hyps f + if kind1 = NoHead && kind2 = NoHead then + (try reduce_delta ri env f + with NotRed NoHead -> reduce_tc ri env f) else raise needsubterm end @@ -1195,9 +1224,18 @@ and reduce_head_top_force ri env onhead f = match reduce_head_sub ri env f with | f -> if onhead then reduce_head_top ri env ~onhead f else f - | exception (NotRed _) -> - try reduce_delta ri.ri env ri.hyps f - with NotRed _ -> RedTbl.set_norm ri.redtbl f; raise nohead + | exception (NotRed _) -> begin + match + List.find_map_opt + (fun cb -> try Some (cb ri.ri env f) with NotRed _ -> None) + [reduce_delta; reduce_tc] + with + | Some f -> + f + | None -> + RedTbl.set_norm ri.redtbl f; + raise nohead + end end and reduce_head_sub ri env f = @@ -1258,30 +1296,25 @@ let rec simplify ri env f = match f.f_node with | FhoareF hf when ri.ri.modpath -> let hf_f = EcEnv.NormMp.norm_xfun env hf.hf_f in - f_map (fun ty -> ty) (simplify ri env) - (f_hoareF (hf_pr hf) hf_f (hf_po hf)) + f_map (fun ty -> ty) (simplify ri env) (f_hoareF_r { hf with hf_f }) | FeHoareF hf when ri.ri.modpath -> let ehf_f = EcEnv.NormMp.norm_xfun env hf.ehf_f in - f_map (fun ty -> ty) (simplify ri env) - (f_eHoareF (ehf_pr hf) ehf_f (ehf_po hf)) + f_map (fun ty -> ty) (simplify ri env) (f_eHoareF_r { hf with ehf_f }) | FbdHoareF hf when ri.ri.modpath -> let bhf_f = EcEnv.NormMp.norm_xfun env hf.bhf_f in - f_map (fun ty -> ty) (simplify ri env) - (f_bdHoareF (bhf_pr hf) bhf_f (bhf_po hf) hf.bhf_cmp (bhf_bd hf)) + f_map (fun ty -> ty) (simplify ri env) (f_bdHoareF_r { hf with bhf_f }) | FequivF ef when ri.ri.modpath -> let ef_fl = EcEnv.NormMp.norm_xfun env ef.ef_fl in let ef_fr = EcEnv.NormMp.norm_xfun env ef.ef_fr in - f_map (fun ty -> ty) (simplify ri env) - (f_equivF (ef_pr ef) ef_fl ef_fr (ef_po ef)) + f_map (fun ty -> ty) (simplify ri env) (f_equivF_r { ef with ef_fl; ef_fr; }) | FeagerF eg when ri.ri.modpath -> let eg_fl = EcEnv.NormMp.norm_xfun env eg.eg_fl in let eg_fr = EcEnv.NormMp.norm_xfun env eg.eg_fr in - f_map (fun ty -> ty) (simplify ri env) - (f_eagerF (eg_pr eg) eg.eg_sl eg_fl eg_fr eg.eg_sr (eg_po eg)) + f_map (fun ty -> ty) (simplify ri env) (f_eagerF_r { eg with eg_fl ; eg_fr; }) | Fpr pr when ri.ri.modpath -> let pr_fun = EcEnv.NormMp.norm_xfun env pr.pr_fun in @@ -1410,6 +1443,9 @@ let zpop ri side f hd = let rec conv ri env f1 f2 stk = if f_equal f1 f2 then conv_next ri env f1 stk else match f1.f_node, f2.f_node with + | Flocal x, Flocal y when EcIdent.id_equal x y -> + true + | Fquant (q1, bd1, f1'), Fquant(q2,bd2,f2') -> if q1 <> q2 then force_head_sub ri env f1 f2 stk else @@ -1463,7 +1499,8 @@ let rec conv ri env f1 f2 stk = end | Fop(p1, ty1), Fop(p2,ty2) - when EcPath.p_equal p1 p2 && List.all2 (EqTest_i.for_type env) ty1 ty2 -> + when EcPath.p_equal p1 p2 + && List.all2 (EqTest_i.for_etyarg env) ty1 ty2 -> conv_next ri env f1 stk | Fapp(f1', args1), Fapp(f2', args2) @@ -1773,8 +1810,8 @@ module User = struct let rule = let rec rule (f : form) : EcTheory.rule_pattern = match EcFol.destr_app f with - | { f_node = Fop (p, tys) }, args -> - R.Rule (`Op (p, tys), List.map rule args) + | { f_node = Fop (p, etyargs) }, args -> + R.Rule (`Op (p, etyargs), List.map rule args) | { f_node = Ftuple args }, [] -> R.Rule (`Tuple, List.map rule args) | { f_node = Fproj (target, i) }, [] -> @@ -1797,12 +1834,13 @@ module User = struct | R.Rule (op, args) -> let ltyvars = match op with - | `Op (_, tys) -> - List.fold_left ( - let rec doit ltyvars = function - | { ty_node = Tvar a } -> Sid.add a ltyvars - | _ as ty -> ty_fold doit ltyvars ty in doit) - cst.cst_ty_vs tys + | `Op (_, etyargs) -> + let rec doit_ty ltyvars = function + | { ty_node = Tvar a } -> Sid.add a ltyvars + | _ as ty -> ty_fold doit_ty ltyvars ty in + List.fold_left + (fun ltyvars (ty, _) -> doit_ty ltyvars ty) + cst.cst_ty_vs etyargs | `Tuple -> cst.cst_ty_vs | `Proj _ -> cst.cst_ty_vs in let cst = {cst with cst_ty_vs = ltyvars } in @@ -1811,7 +1849,7 @@ module User = struct in doit empty_cst rule in let s_bds = Sid.of_list (List.map fst bds) - and s_tybds = Sid.of_list ax.ax_tparams in + and s_tybds = Sid.of_list (List.map fst ax.ax_tparams) in (* Variables appearing in types and formulas are always, respectively, * type and formula variables. diff --git a/src/ecReduction.mli b/src/ecReduction.mli index ceb057d245..605e9a7ae0 100644 --- a/src/ecReduction.mli +++ b/src/ecReduction.mli @@ -19,16 +19,17 @@ type 'a eqantest = env -> ?alpha:(EcIdent.t * ty) Mid.t -> ?norm:bool -> 'a -> ' module EqTest : sig val for_type_exn : env -> ty -> ty -> unit - val for_type : ty eqtest - val for_pv : prog_var eqntest - val for_lv : lvalue eqntest - val for_xp : xpath eqntest - val for_mp : mpath eqntest - val for_instr : instr eqantest - val for_stmt : stmt eqantest - val for_expr : expr eqantest - val for_msig : module_sig eqntest - val for_mexpr : env -> ?norm:bool -> ?body:bool -> module_expr -> module_expr -> bool + val for_type : ty eqtest + val for_etyarg : etyarg eqtest + val for_pv : prog_var eqntest + val for_lv : lvalue eqntest + val for_xp : xpath eqntest + val for_mp : mpath eqntest + val for_instr : instr eqantest + val for_stmt : stmt eqantest + val for_expr : expr eqantest + val for_msig : module_sig eqntest + val for_mexpr : env -> ?norm:bool -> ?body:bool -> module_expr -> module_expr -> bool val is_unit : env -> ty -> bool val is_bool : env -> ty -> bool @@ -64,6 +65,7 @@ type reduction_info = { beta : bool; delta_p : (path -> deltap); (* reduce operators *) delta_h : (ident -> bool); (* reduce local definitions *) + delta_tc : bool; (* reduce tc-operators *) zeta : bool; (* reduce let *) iota : bool; (* reduce case *) eta : bool; (* reduce eta-expansion *) diff --git a/src/ecScope.ml b/src/ecScope.ml index d0125b46fc..910a13dac8 100644 --- a/src/ecScope.ml +++ b/src/ecScope.ml @@ -305,10 +305,22 @@ and proof_state = PSNoCheck | PSCheck of EcCoreGoal.proof and pucflags = { - puc_smt : bool; - puc_local : bool; + puc_smt : bool; + puc_local : bool; } +(* -------------------------------------------------------------------- *) +type docentity = + | ItemDoc of string list * docitem + | SubDoc of (string list * docitem) * docentity list + +and docitem = + mode * itemkind * string * string list + +and itemkind = [`Type | `Operator | `Axiom | `Lemma | `ModuleType | `Module | `Theory] + +and mode = [`Abstract | `Specific] + (* -------------------------------------------------------------------- *) type required_info = { rqd_name : symbol; @@ -337,104 +349,11 @@ type scope = { sc_clears : path list; sc_pr_uc : proof_uc option; sc_options : GenOptions.options; - sc_globdoc : string list; - sc_locdoc : docstate; -} - -and docstate = { - docentities : docentity list; - subdocentbl : docentity list; - docstringbl : string list; - srcstringbl : string list; - currentname : string option; - currentkind : itemkind option; - currentmode : mode option; - currentproc : bool; } -and docentity = - | ItemDoc of string list * docitem - | SubDoc of (string list * docitem) * docentity list - -and docitem = - mode * itemkind * string * string list (* dec/reg, kind, name, src *) - -and itemkind = [`Type | `Operator | `Axiom | `Lemma | `ModuleType | `Module | `Theory] - -and mode = [`Abstract | `Specific] - (* -------------------------------------------------------------------- *) -let get_gdocstrings (sc : scope) : string list = - sc.sc_globdoc - -let get_ldocentities (sc : scope) : docentity list = - sc.sc_locdoc.docentities - -module DocState = struct - let empty : docstate = - { docentities = []; - subdocentbl = []; - docstringbl = []; - srcstringbl = []; - currentname = None; - currentkind = None; - currentmode = None; - currentproc = false; } - - let start_process (state : docstate) (name : string) (kind : itemkind) (md : mode): docstate = - { state with - currentname = Some name; - currentkind = Some kind; - currentmode = Some md; - currentproc = true } - - let prevent_process (state : docstate) : docstate = - { state with - currentname = None; - currentkind = None; - currentmode = None; - currentproc = false } - - let reinitialize_process (state : docstate) : docstate = - { state with - docstringbl = []; - srcstringbl = []; - currentname = None; - currentkind = None; - currentmode = None; - currentproc = false } - - let push_docbl (state : docstate) (docc : string) : docstate = - { state with docstringbl = state.docstringbl @ [docc] } - - let push_srcbl (state : docstate) (srcs : string) : docstate = - { state with srcstringbl = state.srcstringbl @ [srcs] } - - let add_entity (state : docstate) (docent : docentity) : docstate = - { state with docentities = state.docentities @ [docent] } - - let add_item (state : docstate) : docstate = - let state = - if state.currentproc - then - add_entity state (ItemDoc (state.docstringbl, (oget state.currentmode, oget state.currentkind, oget state.currentname, state.srcstringbl))) - else - state - in - reinitialize_process state - - let add_sub (state : docstate) (substate : docstate) : docstate = - let state = - if state.currentproc - then - add_entity state (SubDoc ((state.docstringbl, (oget state.currentmode, oget state.currentkind, oget state.currentname, state.srcstringbl)), - (substate.docentities))) - else - state - in - reinitialize_process state - - end +let get_gdocstrings (_ : scope) : string list = [] +let get_ldocentities (_ : scope) : docentity list = [] (* -------------------------------------------------------------------- *) let empty (gstate : EcGState.gstate) = @@ -447,9 +366,7 @@ let empty (gstate : EcGState.gstate) = sc_required = []; sc_clears = []; sc_pr_uc = None; - sc_options = GenOptions.freeze (); - sc_globdoc = []; - sc_locdoc = DocState.empty; } + sc_options = GenOptions.freeze (); } (* -------------------------------------------------------------------- *) let env (scope : scope) = @@ -569,8 +486,7 @@ let for_loading (scope : scope) = sc_clears = []; sc_pr_uc = None; sc_options = GenOptions.for_loading scope.sc_options; - sc_globdoc = []; - sc_locdoc = DocState.empty; } + } (* -------------------------------------------------------------------- *) let subscope (scope : scope) (mode : EcTheory.thmode) (name : symbol) lc = @@ -584,10 +500,7 @@ let subscope (scope : scope) (mode : EcTheory.thmode) (name : symbol) lc = sc_required = scope.sc_required; sc_clears = []; sc_pr_uc = None; - sc_options = GenOptions.for_subscope scope.sc_options; - sc_globdoc = []; - sc_locdoc = DocState.empty; - } + sc_options = GenOptions.for_subscope scope.sc_options; } (* -------------------------------------------------------------------- *) module Prover = struct @@ -797,7 +710,7 @@ module Tactics = struct let pi scope pi = Prover.do_prover_info scope pi - let proof ?(src : string option) (scope : scope) = + let proof ?src:_ (scope : scope) = check_state `InActiveProof "proof script" scope; match (oget scope.sc_pr_uc).puc_active with @@ -808,14 +721,10 @@ module Tactics = struct hierror "[proof] can only be used at beginning of a proof script"; { pac with puc_started = true } in - { scope with - sc_pr_uc = Some { (oget scope.sc_pr_uc) with puc_active = Some (pac, pct) }; - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } + { scope with sc_pr_uc = + Some { (oget scope.sc_pr_uc) with puc_active = Some (pac, pct); } } - let process_r ?(src : string option) ?reloc mark (mode : proofmode) (scope : scope) (tac : ptactic list) = + let process_r ?reloc mark (mode : proofmode) (scope : scope) (tac : ptactic list) = check_state `InProof "proof script" scope; let scope = @@ -827,13 +736,6 @@ module Tactics = struct else scope in - let scope = { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - let puc = oget (scope.sc_pr_uc) in let pac, pct = oget (puc).puc_active in @@ -874,7 +776,7 @@ module Tactics = struct let pac = { pac with puc_jdg = PSCheck juc } in let puc = { puc with puc_active = Some (pac, pct); } in - let scope = { scope with sc_pr_uc = Some puc; } in + let scope = { scope with sc_pr_uc = Some puc } in Some (penv, hds), scope let process1_r mark mode scope t = @@ -884,8 +786,8 @@ module Tactics = struct let ts = List.map (fun t -> { pt_core = t; pt_intros = []; }) ts in snd (process_r mark mode scope ts) - let process ?(src : string option) scope mode tac = - process_r ?src true mode scope tac + let process ?src:_ scope mode tac = + process_r true mode scope tac end (* -------------------------------------------------------------------- *) @@ -911,7 +813,7 @@ module Auto = struct { scope with sc_env = EcSection.add_item item scope.sc_env } let bind_hint scope ~local ~level ?base axioms = - let item = EcTheory.mkitem ~import:true (Th_auto { level; base; axioms; locality=local} ) in + let item = EcTheory.mkitem ~import:true (Th_auto { level; base; axioms; locality = local; }) in { scope with sc_env = EcSection.add_item item scope.sc_env } let add_hint scope hint = @@ -939,9 +841,7 @@ module Ax = struct let bind ?(import = true) (scope : scope) ((x, ax) : _ * axiom) = assert (scope.sc_pr_uc = None); let item = EcTheory.mkitem ~import (EcTheory.Th_axiom (x, ax)) in - { scope with sc_env = - EcSection.add_item item scope.sc_env; - sc_locdoc = DocState.add_item scope.sc_locdoc; } + { scope with sc_env = EcSection.add_item item scope.sc_env } (* ------------------------------------------------------------------ *) let start_lemma scope (cont, axflags) check ?name (axd, ctxt) = @@ -992,11 +892,15 @@ module Ax = struct let concl = TT.trans_prop env ue pconcl in - if not (EcUnify.UniEnv.closed ue) then - hierror "the formula contains free type variables"; + Option.iter (fun infos -> + hierror + "the formula contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos + ) (EcUnify.UniEnv.xclosed ue); let uidmap = EcUnify.UniEnv.close ue in - let fs = Tuni.subst uidmap in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let fs = Tuni.subst ~tw_uni uidmap in let concl = Fsubst.f_subst fs concl in let tparams = EcUnify.UniEnv.tparams ue in @@ -1006,11 +910,11 @@ module Ax = struct | PAxiom tags -> `Axiom (Ssym.of_list (List.map unloc tags), false) | _ -> `Lemma - in { ax_tparams = tparams; - ax_spec = concl; - ax_kind = kind; - ax_loca = ax.pa_locality; - ax_smt = true; } + in { ax_tparams = tparams; + ax_spec = concl; + ax_kind = kind; + ax_loca = ax.pa_locality; + ax_smt = true; } in match ax.pa_kind with @@ -1133,69 +1037,22 @@ module Ax = struct save_r scope (* ------------------------------------------------------------------ *) - let save ?(src : string option) scope = + let save ?src:_ scope = check_state `InProof "save" scope; - - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in save_r ~mode:`Save scope (* ------------------------------------------------------------------ *) - let admit ?(src : string option) scope = + let admit ?src:_ scope = check_state `InProof "admitted" scope; - - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - save_r ~mode:`Admit scope (* ------------------------------------------------------------------ *) - let abort ?(src : string option) scope = + let abort ?src:_ scope = check_state `InProof "abort" scope; - - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - snd (save_r ~mode:`Abort scope) (* ------------------------------------------------------------------ *) - let add ?(src : string option) (scope : scope) (mode : proofmode) (ax : paxiom located) = - let uax = unloc ax in - let kind = - match uax.pa_kind with - | PLemma _ -> `Lemma - | _ -> `Axiom - in - let scope = - { scope with - sc_locdoc = - match uax.pa_locality with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc (unloc uax.pa_name) kind `Specific - | `Declare -> DocState.start_process scope.sc_locdoc (unloc uax.pa_name) kind `Abstract} - in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in + let add ?src:_ (scope : scope) (mode : proofmode) (ax : paxiom located) = add_r scope mode ax (* ------------------------------------------------------------------ *) @@ -1251,33 +1108,41 @@ module Op = struct let bind ?(import = true) (scope : scope) ((x, op) : _ * operator) = assert (scope.sc_pr_uc = None); let item = EcTheory.mkitem ~import (EcTheory.Th_operator (x, op)) in - { scope with sc_env = - EcSection.add_item item scope.sc_env; - sc_locdoc = DocState.add_item scope.sc_locdoc; } - - let add ?(src : string option) (scope : scope) (op : poperator located) = - assert (scope.sc_pr_uc = None); + { scope with sc_env = EcSection.add_item item scope.sc_env; } - let uop = unloc op in - let scope = - { scope with - sc_locdoc = - match uop.po_locality with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc (unloc uop.po_name) `Operator `Specific - | `Declare -> DocState.start_process scope.sc_locdoc (unloc uop.po_name) `Operator `Abstract } + (* -------------------------------------------------------------------- *) + let axiomatized_op ?(nargs = 0) ?(nosmt = false) path (tparams, axbd) lc = + let axpm, axbd = + let subst, axpm = EcSubst.fresh_tparams EcSubst.empty tparams in + (axpm, EcSubst.subst_form subst axbd) in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } + + let args, axbd = + match axbd.f_node with + | Fquant (Llambda, bds, axbd) -> + let bds, flam = List.split_at nargs bds in + (bds, f_lambda flam axbd) + | _ -> [], axbd in + let opargs = List.map (fun (x, ty) -> f_local x (gty_as_ty ty)) args in + let opty = toarrow (List.map f_ty opargs) axbd.EcAst.f_ty in + let op = f_op_tc path (etyargs_of_tparams axpm) opty in + let op = f_app op opargs axbd.f_ty in + let axspec = f_forall args (f_eq op axbd) in + + { ax_tparams = axpm; + ax_spec = axspec; + ax_kind = `Axiom (Ssym.empty, false); + ax_loca = lc; + ax_smt = if nosmt then false else true; } + + let add ?src:_ (scope : scope) (op : poperator located) = + assert (scope.sc_pr_uc = None); let op = op.pl_desc and loc = op.pl_loc in let eenv = env scope in let ue = TT.transtyvars eenv (loc, op.po_tyvars) in + let lc = op.po_locality in let args = fst op.po_args @ odfl [] (snd op.po_args) in let (ty, body, refts) = @@ -1312,11 +1177,15 @@ module Op = struct (opty, `Abstract, [(rname, xs, reft, codom)]) in - if not (EcUnify.UniEnv.closed ue) then - hierror ~loc "this operator type contains free type variables"; + Option.iter (fun infos -> + hierror ~loc + "this operator type contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos + ) (EcUnify.UniEnv.xclosed ue); let uidmap = EcUnify.UniEnv.close ue in - let ts = Tuni.subst uidmap in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let ts = Tuni.subst ~tw_uni uidmap in let fs = Fsubst.f_subst ts in let ty = ty_subst ts ty in let tparams = EcUnify.UniEnv.tparams ue in @@ -1326,7 +1195,7 @@ module Op = struct | `Plain e -> Some (OP_Plain (fs e)) | `Fix opfx -> Some (OP_Fix { - opf_recp = EcPath.pqname (EcEnv.root eenv) (EcIdent.name opfx.EHI.mf_name); + opf_recp = EcPath.psymbol "_"; opf_args = opfx.EHI.mf_args; opf_resty = opfx.EHI.mf_codom; opf_struct = (opfx.EHI.mf_recs, List.length opfx.EHI.mf_args); @@ -1357,7 +1226,7 @@ module Op = struct try EcUnify.unify eenv tue ty tfun; - let msg = "this operator type is (unifiable) to a function type" in + let msg = "this operator type is (unifiable to) a function type" in hierror ~loc "%s" msg with EcUnify.UnificationFailure _ -> () end; @@ -1370,8 +1239,8 @@ module Op = struct | OB_oper (Some (OP_Plain bd)) -> let path = EcPath.pqname (path scope) (unloc op.po_name) in let axop = - let nargs = List.sum (List.map (List.length -| fst) args) in - EcDecl.axiomatized_op ~nargs path (tyop.op_tparams, bd) lc in + let nargs = List.sum (List.map (fst |- List.length) args) in + axiomatized_op ~nargs path (tyop.op_tparams, bd) lc in let tyop = { tyop with op_opaque = { reduction = true; smt = false; }} in let scope = bind scope (unloc op.po_name, tyop) in Ax.bind scope (unloc ax, axop) @@ -1384,7 +1253,7 @@ module Op = struct List.fold_left (fun scope (rname, xs, ax, codom) -> let ax = let opargs = List.map (fun (x, xty) -> e_local x xty) xs in - let opapp = List.map tvar tparams in + let opapp = List.map (fst |- tvar) tparams in let opapp = e_app (e_op opname opapp ty) opargs codom in let subst = EcSubst.add_opdef EcSubst.empty opname ([], opapp) in @@ -1392,23 +1261,23 @@ module Op = struct let ax = f_forall (List.map (snd_map gtty) xs) ax in let uidmap = EcUnify.UniEnv.close ue in - let subst = Tuni.subst uidmap in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let subst = Tuni.subst ~tw_uni uidmap in let ax = Fsubst.f_subst subst ax in ax in - let ax, axpm = - let bdpm = tparams in - let axpm = List.map EcIdent.fresh bdpm in - (Tvar.f_subst ~freshen:true bdpm (List.map EcTypes.tvar axpm) ax, - axpm) in + let axpm, ax = + let subst, tparams = EcSubst.fresh_tparams EcSubst.empty tparams in + (tparams, EcSubst.subst_form subst ax) in + let ax = - { ax_tparams = axpm; - ax_spec = ax; - ax_kind = `Axiom (Ssym.empty, false); - ax_loca = lc; - ax_smt = true; } + { ax_tparams = axpm; + ax_spec = ax; + ax_kind = `Axiom (Ssym.empty, false); + ax_loca = lc; + ax_smt = true; } in Ax.bind scope (unloc rname, ax)) scope refts in @@ -1419,11 +1288,11 @@ module Op = struct hierror ~loc "multiple names are only allowed for non-refined abstract operators"; let addnew scope name = - let nparams = List.map EcIdent.fresh tparams in - let subst = Tvar.init - tparams - (List.map tvar nparams) in - let rop = EcDecl.mk_op ~opaque:optransparent nparams (Tvar.subst subst ty) None lc in + let subst, nparams = + EcSubst.fresh_tparams EcSubst.empty tparams in + let rop = + EcDecl.mk_op ~opaque:optransparent + nparams (EcSubst.subst_ty subst ty) None lc in bind scope (unloc name, rop) in List.fold_left addnew scope op.po_aliases @@ -1438,10 +1307,18 @@ module Op = struct if not (EcAlgTactic.is_module_loaded (env scope)) then hierror "for tag %s, load Distr first" tag; - let oppath = EcPath.pqname (path scope) (unloc op.po_name) in - let nparams = List.map EcIdent.fresh tyop.op_tparams in - let subst = Tvar.init tyop.op_tparams (List.map tvar nparams) in - let ty = Tvar.subst subst tyop.op_ty in + let subst, nparams = + EcSubst.fresh_tparams EcSubst.empty tyop.op_tparams in + let oppath = EcPath.pqname (path scope) (unloc op.po_name) in + let optyargs = + let mktcw (a : EcIdent.t) (i : int) = + TCIAbstract { support = `Var a; offset = i; lift = [] } + in + List.map + (fun (a, tcs) -> (tvar a, List.mapi (fun i _ -> mktcw a i) tcs)) + nparams + in + let ty = EcSubst.subst_ty subst tyop.op_ty in let aty, rty = EcTypes.tyfun_flat ty in let dty = @@ -1451,17 +1328,17 @@ module Op = struct in let bds = List.combine (List.map EcTypes.fresh_id_of_ty aty) aty in - let ax = EcFol.f_op oppath (List.map tvar nparams) ty in + let ax = EcFol.f_op_tc oppath optyargs ty in let ax = EcFol.f_app ax (List.map (curry f_local) bds) rty in let ax = EcFol.f_app (EcFol.f_op pred [dty] (tfun rty tbool)) [ax] tbool in let ax = EcFol.f_forall (List.map (snd_map gtty) bds) ax in let ax = - { ax_tparams = nparams; - ax_spec = ax; - ax_kind = `Axiom (Ssym.empty, false); - ax_loca = lc; - ax_smt = true; } in + { ax_tparams = nparams; + ax_spec = ax; + ax_kind = `Axiom (Ssym.empty, false); + ax_loca = lc; + ax_smt = true; } in let scope, axname = let axname = Printf.sprintf "%s_%s" (unloc op.po_name) suffix in @@ -1473,9 +1350,7 @@ module Op = struct List.fold_left (fun scope base -> - Auto.bind_hint - ~local:(local_of_locality lc) ~level:0 ~base scope - [(axpath, `Default)]) + Auto.bind_hint ~local:(local_of_locality lc) ~level:0 ~base scope [(axpath, `Default)]) scope bases in @@ -1501,26 +1376,9 @@ module Op = struct tyop, List.rev !axs, scope - let add_opsem ?(src : string option) (scope : scope) (op : pprocop located) = + let add_opsem ?src:_ (scope : scope) (op : pprocop located) = let module Sem = EcProcSem in - let uop = unloc op in - let scope = - { scope with - sc_locdoc = - match uop.ppo_locality with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc (unloc uop.ppo_name) `Operator `Specific - | `Declare -> DocState.start_process scope.sc_locdoc (unloc uop.ppo_name) `Operator `Abstract } - in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - let op = unloc op in let f = EcTyping.trans_gamepath (env scope) op.ppo_target in let sig_, body = @@ -1546,7 +1404,9 @@ module Op = struct (`Det, Sem.translate_e env ret) in let mode, aout = Sem.translate_s env cont body.f_body in - let aout = form_of_expr aout in (* FIXME: translate to forms directly? *) + let aout = + let m = EcIdent.create "&hr" in + form_of_expr ~m aout in (* FIXME: translate to forms directly? *) let aout = f_lambda (List.map2 (fun (_, ty) x -> (x, GTty ty)) params ids) aout in let opdecl = EcDecl.{ @@ -1565,11 +1425,12 @@ module Op = struct let scope = let prax = + let m = EcIdent.create "&hr" in let locs = List.map (fun (x, ty) -> (EcIdent.create x, ty)) params in - let prmem = EcIdent.create "&m" in - let res = f_pvar pv_res sig_.fs_ret prmem in + let res = f_pvar pv_res sig_.fs_ret m in let resx = EcIdent.create "v" in let resv = f_local resx sig_.fs_ret in + let prmem = EcIdent.create "&m" in let mu = let sem = @@ -1594,16 +1455,16 @@ module Op = struct (f_pr prmem f (f_tuple (List.map (fun (x, ty) -> f_local x ty) locs)) - (map_ss_inv1 (fun r -> f_eq r resv) res)) + { m; inv = f_eq res.inv resv }) mu)) in let prax = EcDecl.{ - ax_tparams = []; - ax_spec = prax; - ax_kind = `Lemma; - ax_loca = op.ppo_locality; - ax_smt = true; + ax_tparams = []; + ax_spec = prax; + ax_kind = `Lemma; + ax_loca = op.ppo_locality; + ax_smt = true; } in Ax.bind scope (unloc op.ppo_name ^ "_opsem", prax) in @@ -1612,35 +1473,32 @@ module Op = struct match mode with | `Det -> let hax = + let m = EcIdent.create "&hr" in let locs = List.map (fun (x, ty) -> (EcIdent.create x, ty)) params in - let m = EcIdent.create "&hr" in let res = f_pvar pv_res sig_.fs_ret m in let args = f_pvar pv_arg sig_.fs_arg m in - let post = - f_eq - res.inv - (f_app - (f_op oppath [] opdecl.op_ty) - (List.map (fun (x, ty) -> f_local x ty) locs) - sig_.fs_ret) - in f_forall (List.map (fun (x, ty) -> (x, GTty ty)) locs) (f_hoareF - {m;inv=(f_eq + { m; inv = f_eq args.inv - (f_tuple (List.map (fun (x, ty) -> f_local x ty) locs)))} + (f_tuple (List.map (fun (x, ty) -> f_local x ty) locs)) } f - {hsi_m=m;hsi_inv= POE.empty post}) + (POE.lift { m; inv = f_eq + res.inv + (f_app + (f_op oppath [] opdecl.op_ty) + (List.map (fun (x, ty) -> f_local x ty) locs) + sig_.fs_ret) })) in let prax = EcDecl.{ - ax_tparams = []; - ax_spec = hax; - ax_kind = `Lemma; - ax_loca = op.ppo_locality; - ax_smt = true; + ax_tparams = []; + ax_spec = hax; + ax_kind = `Lemma; + ax_loca = op.ppo_locality; + ax_smt = true; } in Ax.bind scope (unloc op.ppo_name ^ "_opsem_det", prax) @@ -1656,11 +1514,6 @@ end module Exception = struct module TT = EcTyping - let bind ?(import = true) (scope : scope) (x, e) = - assert (scope.sc_pr_uc = None); - let op = operator_of_exception e in - Op.bind ~import scope (x, op) - let add (scope : scope) (pe : pexception_decl located) = assert (scope.sc_pr_uc = None); let loc = loc pe in @@ -1673,7 +1526,8 @@ module Exception = struct if tparams <> [] then hierror ~loc "Polymorphic expression are not allowed"; let e = EcDecl.mk_exception lc e_dom in - let scope = bind scope (unloc pe.pe_name, e) in + let op = EcDecl.operator_of_exception e in + let scope = Op.bind scope (unloc pe.pe_name, op) in e, scope end @@ -1681,26 +1535,9 @@ end module Pred = struct module TT = EcTyping - let add ?(src : string option) (scope : scope) (pr : ppredicate located) = + let add ?src:_ (scope : scope) (pr : ppredicate located) = assert (scope.sc_pr_uc = None); - let upr = unloc pr in - let scope = - { scope with - sc_locdoc = - match upr.pp_locality with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc (unloc upr.pp_name) `Operator `Specific - | `Declare -> DocState.start_process scope.sc_locdoc (unloc upr.pp_name) `Operator `Abstract } - in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - let typr = EcHiPredicates.trans_preddecl (env scope) pr in let scope = Op.bind scope (unloc (unloc pr).pp_name, typr) in typr, scope @@ -1725,34 +1562,14 @@ module Mod = struct let bind ?(import = true) (scope : scope) (m : top_module_expr) = assert (scope.sc_pr_uc = None); let item = EcTheory.mkitem ~import (EcTheory.Th_module m) in - { scope with - sc_env = EcSection.add_item item scope.sc_env; - sc_locdoc = DocState.add_item scope.sc_locdoc; } + { scope with sc_env = EcSection.add_item item scope.sc_env } - let add_concrete ?(src : string option) (scope : scope) lc (ptm : pmodule_def) = + let add_concrete (scope : scope) lc (ptm : pmodule_def) = assert (scope.sc_pr_uc = None); if lc = `Declare then hierror "cannot use [declare] for concrete modules"; - let nm = unloc (EcParsetree.pcmhd_ident ptm.ptm_header) in - - let scope = - { scope with - sc_locdoc = - match lc with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc nm `Module `Specific - | `Declare -> DocState.start_process scope.sc_locdoc nm `Module `Abstract } - in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - let m = TT.transmod (env scope) ~attop:true ptm in let ur = EcModules.get_uninit_read_of_module (path scope) m in @@ -1782,10 +1599,10 @@ module Mod = struct { scope with sc_env = EcSection.add_decl_mod name tysig scope.sc_env } - let add ?(src : string option) (scope : scope) (m : pmodule_def_or_decl) = + let add ?src:_ (scope : scope) (m : pmodule_def_or_decl) = match m with | { ptm_locality = lc; ptm_def = `Concrete def } -> - add_concrete ?src scope lc def + add_concrete scope lc def | { ptm_locality = lc; ptm_def = `Abstract decl } -> if lc <> `Declare then @@ -1806,170 +1623,1133 @@ module ModType = struct = assert (scope.sc_pr_uc = None); let item = EcTheory.mkitem ~import (EcTheory.Th_modtype (x, tysig)) in - { scope with - sc_env = EcSection.add_item item scope.sc_env; - sc_locdoc = DocState.add_item scope.sc_locdoc; } + { scope with sc_env = EcSection.add_item item scope.sc_env } - let add ?(src : string option) (scope : scope) (intf : pinterface) = + let add ?src:_ (scope : scope) (intf : pinterface) = assert (scope.sc_pr_uc = None); - - let scope = - { scope with - sc_locdoc = - match intf.pi_locality with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc (unloc intf.pi_name) `ModuleType `Specific } - in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in let tysig = EcTyping.transmodsig (env scope) intf in bind scope (unloc intf.pi_name, tysig) end (* -------------------------------------------------------------------- *) -module Theory = struct - open EcTheory +(* Forward reference: filled in later by [Cloning] (which depends on + [Theory] which is defined after [Ty]). *) +let subtype_hooks_ref : scope EcTheoryReplay.ovrhooks ref = + ref { EcTheoryReplay.henv = (fun _ -> assert false); + EcTheoryReplay.hadd_item = (fun _ ~import:_ _ -> assert false); + EcTheoryReplay.hthenter = (fun _ _ _ _ -> assert false); + EcTheoryReplay.hthexit = (fun _ ~import:_ _ -> assert false); + EcTheoryReplay.herr = (fun ?loc:_ _ -> assert false); } - exception TopScope +(* -------------------------------------------------------------------- *) +module Ty = struct + open EcDecl + open EcTyping - (* ------------------------------------------------------------------ *) - let bind ?(import = true) (scope : scope) (cth : thloaded) = - assert (scope.sc_pr_uc = None); - { scope with - sc_env = EcSection.add_th ~import cth scope.sc_env } + module TT = EcTyping + module ELI = EcInductive + module EHI = EcHiInductive (* ------------------------------------------------------------------ *) - let required (scope : scope) (rqd : required_info) = - assert (scope.sc_pr_uc = None); - List.exists (fun x -> - if x.rqd_name = rqd.rqd_name then ( - if (x.rqd_digest <> rqd.rqd_digest) then begin - let fullname (ri : required_info) = - let namespace = - ri.rqd_namespace - |> Option.map EcLoader.string_of_namespace - |> Option.map (fun s -> s ^ ":") - |> Option.value ~default:"" in - namespace ^ ri.rqd_name in - hierror - "Digest mismatch, file %s differs from %s" - (fullname x) (fullname rqd) - end; - true) - else false) - scope.sc_required + let check_name_available scope x = + let pname = EcPath.pqname (EcEnv.root (env scope)) x.pl_desc in - (* ------------------------------------------------------------------ *) - let mark_as_direct (scope : scope) (name : symbol) = - let for1 rq = - if rq.rqd_name = name - then { rq with rqd_direct = true } - else rq - in { scope with sc_required = List.map for1 scope.sc_required } + if EcEnv.Ty .by_path_opt pname (env scope) <> None + || EcEnv.TypeClass.by_path_opt pname (env scope) <> None then + hierror ~loc:x.pl_loc "duplicated type/type-class name `%s'" x.pl_desc (* ------------------------------------------------------------------ *) - let enter ?(src : string option) (scope : scope) (mode : thmode) (name : symbol) = + let bind ?(import = true) (scope : scope) ((x, tydecl) : (_ * tydecl)) = assert (scope.sc_pr_uc = None); - let sc_locdoc = scope.sc_locdoc in - let sc_locdoc = - match src with - | None -> DocState.prevent_process scope.sc_locdoc - | Some src -> - let sc_locdoc = - DocState.start_process sc_locdoc name `Theory - (match mode with `Concrete -> `Specific | `Abstract -> `Abstract) - in - DocState.push_srcbl sc_locdoc src - in - let - scope = { scope with sc_locdoc } - in - - subscope scope mode name + let item = EcTheory.mkitem ~import (EcTheory.Th_type (x, tydecl)) in + { scope with sc_env = EcSection.add_item item scope.sc_env } (* ------------------------------------------------------------------ *) - let rec require_loaded (id : required_info) scope = - if required scope id then - scope - else - match Msym.find_opt id.rqd_name scope.sc_loaded with - | Some (rth, ids) -> - let scope = List.fold_right require_loaded ids scope in - let env = EcSection.require rth scope.sc_env in - { scope with - sc_env = env; - sc_required = id :: scope.sc_required; } + let add_subtype (scope : scope) ({ pl_desc = subtype } : psubtype located) = + let loced x = mk_loc _dummy x in + let env = env scope in - | None -> assert false + let carrier = + let ue = EcUnify.UniEnv.create None in + transty tp_tydecl env ue subtype.pst_carrier in - (* ------------------------------------------------------------------ *) - let update_with_required ~(dst : scope) ~(src : scope) = - let dst = - let sc_loaded = - Msym.union - (fun _ x y -> assert (x ==(*phy*) y); Some x) - dst.sc_loaded src.sc_loaded - in { dst with sc_loaded } - in List.fold_right require_loaded src.sc_required dst + let pred = + let x = EcIdent.create (fst subtype.pst_pred).pl_desc in + let env = EcEnv.Var.bind_local x carrier env in + let ue = EcUnify.UniEnv.create None in + let pred = EcTyping.trans_prop env ue (snd subtype.pst_pred) in + if not (EcUnify.UniEnv.closed ue) then + hierror ~loc:(snd subtype.pst_pred).pl_loc + "the predicate contains free type variables"; + let uidmap = EcUnify.UniEnv.close ue in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let fs = EcCoreSubst.Tuni.subst ~tw_uni uidmap in + f_lambda [(x, GTty carrier)] (Fsubst.f_subst fs pred) in - (* ------------------------------------------------------------------ *) - let add_clears clears scope = - let clears = - let for1 = function - | None -> EcEnv.root (env scope) - | Some { pl_loc = loc; pl_desc = (xs, x) as q } -> - let xp = EcEnv.root (env scope) in - let xp = EcPath.pqname (EcPath.extend xp xs) x in - if is_none (EcEnv.Theory.by_path_opt xp (env scope)) then - hierror ~loc "unknown theory: `%s`" (string_of_qsymbol q); - xp - in List.map for1 clears - in { scope with sc_clears = scope.sc_clears @ clears } + let scope = + let decl = EcDecl.{ + tyd_params = []; + tyd_type = `Abstract []; + tyd_resolve = true; + tyd_loca = `Global; + (* Carry the carrier+predicate so [tydecl_fv] picks up the + dependency on section-declared types and [generalize_tydecl] + produces the right tparams at section close. *) + tyd_subtype = Some (carrier, pred); + } in bind scope (unloc subtype.pst_name, decl) in - (* -------------------------------------------------------------------- *) - let exit_r ?pempty (scope : scope) = - match scope.sc_top with - | None -> raise TopScope - | Some sup -> - let clears = scope.sc_clears in - let _, cth, _ = EcSection.exit_theory ?pempty ~clears scope.sc_env in - let loaded = scope.sc_loaded in - let required = scope.sc_required in - let sup = { - sup with - sc_loaded = loaded; - sc_locdoc = DocState.add_sub sup.sc_locdoc scope.sc_locdoc} in - ((cth, required), scope.sc_name, sup) + let evclone : EcThCloning.evclone = + let t_entry : EcThCloning.xty_override = (`Direct carrier, `Inline `Clear) in + let st_entry : EcThCloning.xty_override = + ((`ByPath + (EcPath.pqname (EcEnv.root env) (unloc subtype.pst_name)) + :> [`ByPath of EcPath.path | `BySyntax of EcParsetree.ty_override_def | `Direct of EcAst.ty]), + `Inline `Clear) in + let p_entry : EcThCloning.xop_override = (`Direct pred, `Inline `Clear) in + { EcThCloning.evc_empty with + evc_types = Msym.of_list [ + "T", loced t_entry; + "sT", loced st_entry; + ]; + evc_ops = Msym.of_list [ + "P", loced p_entry; + ]; + evc_lemmas = { + ev_bynames = Msym.empty; + ev_global = [ (None, Some [`Include, "prove"]) ] + } } in - (* ------------------------------------------------------------------ *) - let exit ?import ?(pempty = `ClearOnly) ?(clears =[]) (scope : scope) = - assert (scope.sc_pr_uc = None); + let cname = Option.map unloc subtype.pst_cname in + let npath = ofold ((^~) EcPath.pqname) (EcEnv.root env) cname in + let cpath = EcPath.fromqsymbol ([EcCoreLib.i_top], "Subtype") in + let theory = EcEnv.Theory.by_path ~mode:`Abstract cpath env in - let cth = exit_r ~pempty (add_clears clears scope) in - let ((cth, required), (name, _), scope) = cth in - let scope = List.fold_right require_loaded required scope in - let scope = ofold (fun cth scope -> bind ?import scope cth) scope cth in - (name, scope) + let renames : EcThCloning.renaming list = + match subtype.pst_rename with + | None -> [] + | Some (insub, val_) -> [ + (`All, (EcRegexp.regexp "val", EcRegexp.subst val_)); + (`All, (EcRegexp.regexp "insub", EcRegexp.subst insub)); + ] in - (* ------------------------------------------------------------------ *) - let bump_prelude (scope : scope) = - match scope.sc_prelude with - | `InPrelude, _ -> - { scope with sc_prelude = (`InPrelude, - { pr_env = env scope; - pr_required = scope.sc_required; }) } - | _ -> scope + let theory = theory.cth_items in + + let (proofs, scope) = + EcTheoryReplay.replay !subtype_hooks_ref + ~abstract:false ~override_locality:None ~incl:(Option.is_none cname) + ~clears:Sp.empty ~renames ~opath:cpath ~npath + evclone scope + (Option.value ~default:(EcPath.basename cpath) cname, theory, `Global) + in + let proofs = + List.pmap (fun axc -> + match axc.EcThCloning.axc_tac with + | None -> + Some (fst_map some axc.EcThCloning.axc_axiom, + axc.EcThCloning.axc_path, + axc.EcThCloning.axc_env) + | Some _ -> + (* tactic-bearing proofs require Tactics.process_r which + isn't available at this point (defined after Ty); they + are not produced by Subtype's evclone (which only + provides ev_global), so this branch is unreachable. *) + assert false) + proofs + in + Ax.add_defer scope proofs (* ------------------------------------------------------------------ *) - let import (scope : scope) (name : qsymbol) = - assert (scope.sc_pr_uc = None); + let add ?src:_ scope (tyd : ptydecl located) = + let loc = loc tyd in - match EcEnv.Theory.lookup_opt ~mode:`All name (env scope) with + let { pty_name = name; pty_tyvars = args; + pty_body = body; pty_locality = tyd_loca } = unloc tyd in + + check_name_available scope name; + let env = env scope in + let tyd_params, tyd_type = + match body with + | PTYD_Abstract tcs -> + let ue = TT.transtyvars env (loc, Some args) in + let tcs = List.map (TT.transtc env ue) tcs in + let tp = EcUnify.UniEnv.tparams ue in + tp, `Abstract tcs + + | PTYD_Alias bd -> + let ue = TT.transtyvars env (loc, Some args) in + let body = transty tp_tydecl env ue bd in + EcUnify.UniEnv.tparams ue, `Concrete body + + | PTYD_Datatype dt -> + let datatype = EHI.trans_datatype env (mk_loc loc (args,name)) dt in + let ty_from_ctor ctor = EcEnv.Ty.by_path ctor env in + let tparams, tydt = + try + ELI.check_positivity ty_from_ctor datatype; + ELI.datatype_as_ty_dtype datatype + with ELI.NonPositive ctx -> EHI.dterror loc env (EHI.DTE_NonPositive (unloc name, ctx)) + in + tparams, `Datatype tydt + + | PTYD_Record rt -> + let record = EHI.trans_record env (mk_loc loc (args,name)) rt in + let scheme = ELI.indsc_of_record record in + record.ELI.rc_tparams, `Record (scheme, record.ELI.rc_fields) + in + + bind scope (unloc name, + { tyd_params; tyd_type; tyd_loca; tyd_resolve = true; + tyd_subtype = None; }) + + (* ------------------------------------------------------------------ *) + let bindclass ?(import = true) (scope : scope) (x, tc) = + assert (scope.sc_pr_uc = None); + let item = EcTheory.mkitem ~import (EcTheory.Th_typeclass(x, tc)) in + { scope with sc_env = EcSection.add_item item scope.sc_env } + + (* ------------------------------------------------------------------ *) + let add_class (scope : scope) { pl_desc = tcd; pl_loc = loc } = + assert (scope.sc_pr_uc = None); + let lc = tcd.ptc_loca in + let name = unloc tcd.ptc_name in + let scenv = (env scope) in + + check_name_available scope tcd.ptc_name; + + let tclass = + (* Check typeclasses arguments *) + let ue = TT.transtyvars scenv (loc, tcd.ptc_params) in + + let uptcs = + let parent_ue = EcUnify.UniEnv.copy ue in + let uptcs = List.map + (fun (p, ren) -> + (TT.transtc scenv parent_ue p, + List.map (fun (s, t) -> (unloc s, unloc t)) ren)) + tcd.ptc_inth in + let subst = Tuni.subst + ~tw_uni:(EcUnify.UniEnv.tw_assubst parent_ue) + (EcUnify.UniEnv.close parent_ue) in + List.map (fun (tcp, ren) -> + ({ tcp with tc_args = List.map (EcCoreSubst.etyarg_subst subst) tcp.tc_args }, + ren)) + uptcs in + + (* The carrier's [tcs] should reference the class being declared + (so its own ops can be resolved via [Abs mypath, l=0]) and the + parent class is reachable via the ancestor chain. To make + [EcTypeClass.ancestors] work during axiom typing, we pre-bind + a stub typeclass record. The full record replaces the stub at + end of [add_class]. *) + let mypath = EcPath.pqname (path scope) name in + let stub_tc : tc_decl = { + tc_tparams = EcUnify.UniEnv.tparams ue; + tc_prts = uptcs; + tc_ops = []; + tc_axs = []; + tc_loca = lc; + } in + let scenv = + EcEnv.TypeClass.rebind name stub_tc scenv in + + let tc_self = + { tc_name = mypath; + tc_args = EcDecl.etyargs_of_tparams stub_tc.tc_tparams; } in + let asty = + { tyd_params = []; + tyd_type = `Abstract [tc_self]; + tyd_resolve = true; + tyd_loca = (lc :> locality); + tyd_subtype = None; } in + let scenv = EcEnv.Ty.bind name asty scenv in + + (* Check for duplicated field names *) + Msym.odup unloc (List.map fst tcd.ptc_ops) + |> oiter (fun (x, y) -> hierror ~loc:y.pl_loc + "duplicated operator name: `%s'" x.pl_desc); + Msym.odup unloc (List.map fst tcd.ptc_axs) + |> oiter (fun (x, y) -> hierror ~loc:y.pl_loc + "duplicated axiom name: `%s'" x.pl_desc); + + (* Check operators types *) + let operators = + let check1 (x, ty) = + let ue = EcUnify.UniEnv.copy ue in + let ty = transty tp_tydecl scenv ue ty in + let uidmap = + try EcUnify.UniEnv.close ue + with EcUnify.UninstanciateUni _ -> + hierror ~loc:x.pl_loc + "operator `%s' has free type/typeclass variables in its type. \ + Provide an explicit type instantiation (e.g. via `<:%s>`) to \ + fix the carrier." + (unloc x) (unloc tcd.ptc_name) + in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let ty = ty_subst (Tuni.subst ~tw_uni uidmap) ty in + (EcIdent.create (unloc x), ty) + in + tcd.ptc_ops |> List.map check1 in + + (* Check axioms *) + let axioms = + let scenv = EcEnv.Var.bind_locals operators scenv in + let check1 (x, ax) = + let ue = EcUnify.UniEnv.copy ue in + let ax = trans_prop scenv ue ax in + let uidmap = + try EcUnify.UniEnv.close ue + with EcUnify.UninstanciateUni _ -> + hierror ~loc:x.pl_loc + "axiom `%s' is type-ambiguous: free type/typeclass variables \ + remain after typing. Provide an explicit type instantiation \ + (e.g. via `<:%s>`) to fix the carrier." + (unloc x) (unloc tcd.ptc_name) + in + let tw_uni = EcUnify.UniEnv.tw_assubst ue in + let fs = Tuni.subst ~tw_uni uidmap in + let ax = Fsubst.f_subst fs ax in + (unloc x, ax) + in + tcd.ptc_axs |> List.map check1 in + + (* Construct actual type-class *) + { tc_prts = uptcs; tc_tparams = EcUnify.UniEnv.tparams ue; + tc_ops = operators; tc_axs = axioms; tc_loca = lc; } + in + bindclass scope (name, tclass) + + (* ------------------------------------------------------------------ *) + let check_tci_operators env tcty ops reqs = + let ue = EcUnify.UniEnv.create (Some (fst tcty)) in + + let ops = + let tt1 m (x, (tvi, op)) = + if not (Mstr.mem (unloc x) reqs) then + hierror ~loc:x.pl_loc "invalid operator name: `%s'" (unloc x); + + let tvi = List.map (TT.transty tp_tydecl env ue) tvi in + let tvi = List.map (fun ty -> (Some ty, None)) tvi in + let selected = + EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper) + (Some (EcUnify.TVIunamed tvi)) env (unloc op) ue [] + in + let op = + match selected with + | [] -> hierror ~loc:op.pl_loc "unknown operator" + | op1 :: op2 :: _ -> + hierror ~loc:op.pl_loc + "ambiguous operator (%s / %s)" + (EcPath.tostring (fst (proj4_1 op1))) + (EcPath.tostring (fst (proj4_1 op2))) + | [((p, opparams), opty, subue, _)] -> + let subst = Tuni.subst + ~tw_uni:(EcUnify.UniEnv.tw_assubst subue) + (EcUnify.UniEnv.assubst subue) in + let opty = ty_subst subst opty in + let opparams = List.map (EcCoreSubst.etyarg_subst subst) opparams in + ((p, opparams), opty) + + in + Mstr.change + (function + | None -> Some (x.pl_loc, op) + | Some _ -> hierror ~loc:(x.pl_loc) + "duplicated operator name: `%s'" (unloc x)) + (unloc x) m + in + List.fold_left tt1 Mstr.empty ops + in + Mstr.iter + (fun x (req, _) -> + if req && not (Mstr.mem x ops) then + hierror "no definition for operator `%s'" x) + reqs; + Mstr.fold + (fun x (_, ty) m -> + match Mstr.find_opt x ops with + | None -> m + | Some (loc, ((p, opparams), opty)) -> + if not (EcReduction.EqTest.for_type env ty opty) then begin + let ppe = EcPrinting.PPEnv.ofenv env in + hierror ~loc +"invalid type for operator `%s':@\n\ +\ - expected: %a@\n\ +\ - got : %a" + x (EcPrinting.pp_type ppe) ty (EcPrinting.pp_type ppe) opty + end; Mstr.add x (p, opparams) m) + reqs Mstr.empty + + (* ------------------------------------------------------------------ *) + let check_tci_axioms ?(tparams = []) scope mode axs reqs lc = + let rmap = Mstr.of_list reqs in + let symbs, axs = + List.map_fold + (fun m (x, t) -> + if not (Mstr.mem (unloc x) rmap) then + hierror ~loc:x.pl_loc "invalid axiom name: `%s'" (unloc x); + if Sstr.mem (unloc x) m then + hierror ~loc:(x.pl_loc) "duplicated axiom name: `%s'" (unloc x); + (Sstr.add (unloc x) m, (unloc x, t, Mstr.find (unloc x) rmap))) + Sstr.empty axs in + + let interactive = + List.pmap + (fun (x, req) -> + if not (Mstr.mem x symbs) then + let ax = { + ax_tparams = tparams; + ax_spec = req; + ax_kind = `Lemma; + ax_loca = lc; + ax_smt = false; + } in Some ((None, ax), EcPath.psymbol x, scope.sc_env) + else None) + reqs in + List.iter + (fun (x, pt, f) -> + let x = "$" ^ x in + let t = { pt_core = pt; pt_intros = []; } in + let t = { pl_loc = pt.pl_loc; pl_desc = Pby (Some [t]) } in + let t = { pt_core = t; pt_intros = []; } in + let ax = { + ax_tparams = tparams; + ax_spec = f; + ax_kind = `Lemma; + ax_smt = false; + ax_loca = lc; + } in + + let pucflags = { puc_smt = true; puc_local = false; } in + let pucflags = (([], None), pucflags) in + let check = Check_mode.check scope.sc_options in + + let escope = scope in + let escope = Ax.start_lemma escope pucflags check ~name:x (ax, None) in + let escope = Tactics.proof escope in + let escope = snd (Tactics.process_r ~reloc:x false mode escope [t]) in + ignore (Ax.save_r escope)) + axs; + interactive + + (* ------------------------------------------------------------------ *) + let get_ring_field_op (name : string) (symbols : (path * etyarg list) Mstr.t) = + Option.map + (fun (p, tys) -> assert (List.is_empty tys); p) + (Mstr.find_opt name symbols) + + let ring_of_symmap env ty kind symbols = + { r_type = ty; + r_zero = oget (get_ring_field_op "rzero" symbols); + r_one = oget (get_ring_field_op "rone" symbols); + r_add = oget (get_ring_field_op "add" symbols); + r_opp = (get_ring_field_op "opp" symbols); + r_mul = oget (get_ring_field_op "mul" symbols); + r_exp = (get_ring_field_op "expr" symbols); + r_sub = (get_ring_field_op "sub" symbols); + r_kind = kind; + r_embed = + (match get_ring_field_op "ofint" symbols with + | None when EcReduction.EqTest.for_type env ty tint -> `Direct + | None -> `Default | Some p -> `Embed p); } + + let addring ~import (scope : scope) mode (kind, { pl_desc = tci; pl_loc = loc }) = + let env = env scope in + if not (EcAlgTactic.is_module_loaded env) then + hierror "load AlgTactic/Ring first"; + + let ty = + let ue = TT.transtyvars env (loc, Some (fst tci.pti_type)) in + let ty = transty tp_tydecl env ue (snd tci.pti_type) in + assert (EcUnify.UniEnv.closed ue); + let uidmap = EcUnify.UniEnv.close ue in + (EcUnify.UniEnv.tparams ue, ty_subst (Tuni.subst uidmap) ty) + in + + if not (List.is_empty (fst ty)) then + hierror "ring instances cannot be polymorphic"; + + let symbols = EcAlgTactic.ring_symbols env kind (snd ty) in + let symbols = Mstr.of_list symbols in + let symbols = check_tci_operators env ty tci.pti_ops symbols in + let cr = ring_of_symmap env (snd ty) kind symbols in + let axioms = EcAlgTactic.ring_axioms env cr in + let lc = (tci.pti_loca :> locality) in + let inter = check_tci_axioms scope mode tci.pti_axs axioms lc in + + let instance = EcTheory. + { tci_params = fst ty + ; tci_type = snd ty + ; tci_instance = `Ring cr + ; tci_local = (tci.pti_loca :> locality) + ; tci_parents = [] } in + + let scope = + let item = EcTheory.Th_instance (None, instance) in + let item = EcTheory.mkitem ~import item in + { scope with sc_env = EcSection.add_item item scope.sc_env } in + + Ax.add_defer scope inter + + (* ------------------------------------------------------------------ *) + let field_of_symmap env ty symbols = + { f_ring = ring_of_symmap env ty `Integer symbols; + f_inv = oget (get_ring_field_op "inv" symbols); + f_div = get_ring_field_op "div" symbols; } + + let addfield ~import (scope : scope) mode { pl_desc = tci; pl_loc = loc; } = + let env = env scope in + if not (EcAlgTactic.is_module_loaded env) then + hierror "load AlgTactic/Ring first"; + + let ty = + let ue = TT.transtyvars env (loc, Some (fst tci.pti_type)) in + let ty = transty tp_tydecl env ue (snd tci.pti_type) in + assert (EcUnify.UniEnv.closed ue); + let uidmap = EcUnify.UniEnv.close ue in + (EcUnify.UniEnv.tparams ue, ty_subst (Tuni.subst uidmap) ty) + in + + if not (List.is_empty (fst ty)) then + hierror "field instances cannot be polymorphic"; + + let symbols = EcAlgTactic.field_symbols env (snd ty) in + let symbols = Mstr.of_list symbols in + let symbols = check_tci_operators env ty tci.pti_ops symbols in + let cr = field_of_symmap env (snd ty) symbols in + let axioms = EcAlgTactic.field_axioms env cr in + let lc = (tci.pti_loca :> locality) in + let inter = check_tci_axioms scope mode tci.pti_axs axioms lc; in + + let instance = EcTheory. + { tci_params = fst ty + ; tci_type = snd ty + ; tci_instance = `Field cr + ; tci_local = (tci.pti_loca :> locality) + ; tci_parents = [] } in + + let scope = + let item = EcTheory.Th_instance (None, instance) in + let item = EcTheory.mkitem ~import item in + { scope with sc_env = EcSection.add_item item scope.sc_env } in + + Ax.add_defer scope inter + + (* ------------------------------------------------------------------ *) + let symbols_of_tc (_env : EcEnv.env) ((tparams, ty) : ty_params * ty) (tcp, tc) = + (* The instance's tparams are the same idents we'll later resolve + op-clause RHSs against (via [check_tci_operators], which creates + a [unienv] seeded with these tparams). Build the substitution + binding each tparam to itself (with appropriate TC witnesses on + the original ident) — DO NOT freshen, otherwise the expected op + type uses [c_fresh] while the resolved op uses the instance's + [c_orig], and [EqTest.for_type] rejects them as different + (printing identically as ['a] but with different idents). *) + let subst = + List.fold_left + (fun s (x, tcs) -> + let tcw = + List.mapi (fun i _ -> + EcAst.TCIAbstract { + support = `Var x; offset = i; lift = []; + }) tcs + in EcSubst.add_tyvar s x (EcTypes.tvar x, tcw)) + EcSubst.empty tparams in + let subst = EcSubst.add_tydef subst tcp.tc_name ([], ty, []) in + let subst = + List.fold_left + (fun subst (a, ty) -> EcSubst.add_tyvar subst a ty) + subst (List.combine (List.fst tc.tc_tparams) tcp.tc_args) in + + List.map (fun (x, opty) -> + (EcIdent.name x, (true, EcSubst.subst_ty subst opty))) + tc.tc_ops + + (* ------------------------------------------------------------------ *) + let add_generic_instance + ~import (scope : scope) mode { pl_desc = tci; pl_loc = loc; } + = + let (typarams, _) as ty = + let ue = TT.transtyvars (env scope) (loc, Some (fst tci.pti_type)) in + let ty = transty tp_tydecl (env scope) ue (snd tci.pti_type) in + assert (EcUnify.UniEnv.closed ue); + ( + EcUnify.UniEnv.tparams ue, + ty_subst (Tuni.subst (EcUnify.UniEnv.close ue)) ty + ) + in + + let tcp = + let ue = EcUnify.UniEnv.create (Some typarams) in + let tcp = TT.transtc (env scope) ue tci.pti_tc in + let subst = Tuni.subst (EcUnify.UniEnv.close ue) in + { tcp with tc_args = List.map (EcCoreSubst.etyarg_subst subst) tcp.tc_args } in + + (* Walk the parent DAG with cumulative op renamings: [(tcp, []); + (parent_1, ren_1); ...]. The empty renaming on [tcp] is identity: + each ancestor op [n] maps to a local op [n]. Renamings on parent + edges (declared via [<: P with { ... }]) compose along the path, + so a renamed grandparent op resolves to a local op. *) + let chain = EcTypeClass.ancestors_with_renaming (env scope) tcp in + let chain_decls = + List.map + (fun (anc, ren) -> + (anc, EcEnv.TypeClass.by_path anc.tc_name (env scope), ren)) + chain in + + let lookup_ren ren n = odfl n (Mstr.find_opt n (Mstr.of_list ren)) in + + (* Build the set of expected operators. Immediate-class ops (no + renaming applied) are required. Ancestor ops are mapped through + the cumulative renaming to local op names; if an ancestor op's + local name is already in [tcsyms] (e.g. via the immediate class + or via another path), keep the existing entry. *) + let tcsyms = + match chain_decls with + | [] -> assert false + | (tcp_self, tc_self, _) :: rest -> + let immediate = symbols_of_tc (env scope) ty (tcp_self, tc_self) in + let immediate_set = Sstr.of_list (List.map fst immediate) in + let parents = + List.concat_map + (fun (anc, anc_decl, ren) -> + symbols_of_tc (env scope) ty (anc, anc_decl) + |> List.map (fun (n, (_, opty)) -> + (lookup_ren ren n, (false, opty)))) + rest in + let parents = + List.filter + (fun (n, _) -> not (Sstr.mem n immediate_set)) + parents in + Mstr.of_list (immediate @ parents) in + let symbols = check_tci_operators (env scope) ty tci.pti_ops tcsyms in + + (* For any ancestor op (after renaming) the user didn't provide, + look up an existing instance of that ancestor on the same + carrier and reuse its realisation. *) + let existing_anc_symbols anc = + List.fold_left (fun acc (_, tci_existing) -> + match acc with + | Some _ -> acc + | None -> + match tci_existing.EcTheory.tci_instance with + | `General (tgp, Some sym) + when EcPath.p_equal tgp.tc_name anc.tc_name + && EcReduction.EqTest.for_type + (env scope) tci_existing.EcTheory.tci_type (snd ty) -> + Some sym + | _ -> None) + None (EcEnv.TcInstance.get_all (env scope)) in + let symbols = + List.fold_left + (fun symbols (anc, anc_decl, ren) -> + let missing = + List.filter (fun (id, _) -> + not (Mstr.mem (lookup_ren ren (EcIdent.name id)) symbols)) + anc_decl.tc_ops in + if missing = [] then symbols + else + match existing_anc_symbols anc with + | None -> + let id, _ = List.hd missing in + hierror "no definition for operator `%s'" (EcIdent.name id) + | Some sym -> + List.fold_left + (fun symbols (id, _) -> + let n = EcIdent.name id in + let local_n = lookup_ren ren n in + match Mstr.find_opt n sym with + | Some s -> Mstr.add local_n s symbols + | None -> symbols) + symbols missing) + symbols (List.tl chain_decls) in + + (* Phase B coherence check: when a chain entry derives an instance + of [anc] on the carrier and an instance for the same ancestor + on the same carrier already exists in scope, the two must + agree on every op realisation. Catches the case where a user + declares `instance addgroup with int { ... }` and later + `instance comring with int { ... }` with conflicting +. *) + List.iter + (fun (anc, anc_decl, ren) -> + match existing_anc_symbols anc with + | None -> () + | Some existing_sym -> + List.iter + (fun (id, _) -> + let n = EcIdent.name id in + let local_n = lookup_ren ren n in + match Mstr.find_opt local_n symbols, Mstr.find_opt n existing_sym with + | Some (p1, _), Some (p2, _) when not (EcPath.p_equal p1 p2) -> + hierror + "diamond coherence violation: registering an instance \ + of `%s' on this carrier requires op `%s' to be `%s', \ + but an existing instance binds it to `%s'" + (EcPath.tostring anc.tc_name) + n + (EcPath.tostring p1) + (EcPath.tostring p2) + | _ -> ()) + anc_decl.tc_ops) + chain_decls; + + (* Pre-compute the path each chain entry will receive when it is + registered as a [Th_instance] below. We need these paths up + front so the [add_tydef] substitution can reference them as + concrete witnesses — the inherited axiom bodies use + [`Abs anc.tc_name; offset = 0] which, in the class body's + semantics, refers to "the carrier-as-this-class". After + substituting the carrier with [ty], that needs to point at + the instance for [ty] of [anc] — i.e. exactly the path we are + about to register. *) + (* For each chain entry, first check whether an existing instance in + the env already realises [(anc, snd ty)] with op-symbols matching + what this declaration would produce. If so, reuse its path rather + than register a duplicate (which would diverge witnesses and + violate one-canonical-instance-per-(class, carrier)). The + returned [name option] is [None] when reusing — the chain + registration loop below skips those entries. *) + let find_existing_chain_entry (anc : typeclass) (anc_decl : tc_decl) ren = + let expected = + List.fold_left + (fun m (id, _) -> + let n = EcIdent.name id in + let local = lookup_ren ren n in + match Mstr.find_opt local symbols with + | Some s -> Mstr.add n s m + | None -> m) + Mstr.empty anc_decl.tc_ops in + let same_symbols (existing_syms : (path * etyarg list) Mstr.t) = + Mstr.for_all + (fun n (p, _) -> + match Mstr.find_opt n existing_syms with + | Some (p', _) -> EcPath.p_equal p p' + | None -> false) + expected in + (* Carrier-type comparison must be alpha-equivalent (ignore tparam + identity), since the existing instance's tparams have their own + fresh idents that don't match the user's tparams here. Use + [EcTypeClass.ty_match] with the existing instance's tparams as + pattern variables. *) + let same_carrier (tci_existing : EcTheory.tcinstance) = + try + let _ : ty option Mid.t = + EcTypeClass.ty_match (env scope) + (List.fst tci_existing.EcTheory.tci_params) + ~pattern:tci_existing.EcTheory.tci_type + ~ty:(snd ty) + in true + with EcTypeClass.NoMatch -> false in + List.opick + (fun (path_opt, tci_existing) -> + match path_opt with + | None -> None + | Some p -> + if same_carrier tci_existing + && (match tci_existing.EcTheory.tci_instance with + | `General (anc', Some syms) -> + EcPath.p_equal anc'.tc_name anc.tc_name + && same_symbols syms + | _ -> false) + then Some p else None) + (EcEnv.TcInstance.get_all (env scope)) in + let chain_paths_pre = + List.mapi + (fun idx (anc, anc_decl, ren) -> + match find_existing_chain_entry anc anc_decl ren with + | Some existing_path -> (None, existing_path) + | None -> + let name = + if idx = 0 then + match tci.pti_name with + | Some name -> unloc name + | None -> + Printf.sprintf "%s_%d" + (EcPath.basename anc.EcAst.tc_name) (EcUid.unique ()) + else + Printf.sprintf "%s_%d" + (EcPath.basename anc.EcAst.tc_name) (EcUid.unique ()) in + (Some name, EcPath.pqname (path scope) name)) + chain_decls in + + (* Build a substitution mapping every op-ident along the chain to its + chosen realisation on [ty]. For each ancestor the renaming maps + its op names to local op names (via [lookup_ren ren]). *) + let subst, _ = + (* The chain may contain entries sharing a TC name (under + different renamings). [add_tydef] asserts no double-binding, + so we track which TC names we've already added and skip. *) + List.fold_lefti + (fun (subst, seen) idx (anc, anc_decl, ren) -> + let seen, subst = + if EcPath.Sp.mem anc.tc_name seen then (seen, subst) + else + (* The class body referenced its carrier as + [`Abs anc.tc_name; offset = 0; …] (a self-reference, + since [anc]'s own [tcs] contains [anc] itself). After + substituting the carrier with [ty], that reference + must point to the instance for [ty] of [anc] — which + is the chain entry we are about to register. We use + its pre-computed path. The [`Abs] case of + [subst_tcw] then bumps the body's lift onto this + concrete witness, walking [tci_parents] correctly. *) + let _, inst_path = List.nth chain_paths_pre idx in + (* For parametric carriers ([instance C with ['a <: …] (T 'a)]), + the chain instance is registered with the same tparams as the + user's instance. Its witness must therefore re-apply those + tparams as etyargs (each carrying its own abstract TC + witnesses), not [], or [tc_reduce] will hit a + [tci_params]/[etyargs] length mismatch when the instance is + later consulted via this witness. *) + let self_etyargs = + List.map + (fun (x, tcs) -> + let tcws = + List.mapi (fun i _ -> + EcAst.TCIAbstract { + support = `Var x; offset = i; lift = []; + }) tcs + in (EcTypes.tvar x, tcws)) + (fst ty) in + let self_witness = + TCIConcrete { path = inst_path; etyargs = self_etyargs; lift = [] } in + (EcPath.Sp.add anc.tc_name seen, + EcSubst.add_tydef subst anc.tc_name + ([], snd ty, [self_witness])) in + let subst = + List.fold_left + (fun subst (a, ty) -> EcSubst.add_tyvar subst a ty) + subst + (List.combine (List.fst anc_decl.tc_tparams) anc.tc_args) in + let subst = + List.fold_left + (fun subst (opname, ty) -> + let local = lookup_ren ren (EcIdent.name opname) in + let oppath, optys = Mstr.find local symbols in + let op = + EcFol.f_op_tc + oppath + (List.map (EcSubst.subst_etyarg subst) optys) + (EcSubst.subst_ty subst ty) + in EcSubst.add_flocal subst opname op) + subst anc_decl.tc_ops in + (subst, seen)) + (EcSubst.empty, EcPath.Sp.empty) chain_decls in + + let lc = (tci.pti_loca :> locality) in + + (* Compose two renamings (matches the version in [ecTypeClass.ml] + which is used to build the chain). [outer] is declared on the + parent edge; [inner] is the cumulative renaming on this entry. + Result maps grandparent op names to local op names. *) + let compose_ren ~outer ~inner = + let inner_map = Mstr.of_list inner in + let from_outer = + List.map + (fun (gp_name, p_name) -> + let c_name = odfl p_name (Mstr.find_opt p_name inner_map) in + (gp_name, c_name)) + outer in + let outer_p_names = + List.fold_left (fun s (_, p) -> Sstr.add p s) Sstr.empty outer in + let outer_gp_names = + List.fold_left (fun s (gp, _) -> Sstr.add gp s) Sstr.empty outer in + let from_inner = + List.filter_map + (fun (p_name, c_name) -> + if Sstr.mem p_name outer_p_names || Sstr.mem p_name outer_gp_names + then None + else Some (p_name, c_name)) + inner in + from_outer @ from_inner in + let ren_eq r1 r2 = + List.length r1 = List.length r2 + && List.for_all2 (fun (a, b) (c, d) -> a = c && b = d) r1 r2 in + + (* Register one instance per ancestor chain entry, in REVERSE + BFS order (leaves before children) so that when a child entry + is registered, its parents' paths are already known. The + [chain_paths] array uses the pre-computed paths from + [chain_paths_pre] so that proof-obligation substitutions can + reference them ahead of registration. We register BEFORE + [check_tci_axioms] so that the substituted obligation's + concrete witnesses (which point at these paths) resolve + through the env when [tc_reduce] fires. *) + let chain_paths = + Array.of_list + (List.map (fun (_, p) -> Some p) chain_paths_pre) in + let scope = + List.fold_lefti + (fun scope rev_idx (anc, anc_decl, ren) -> + let idx = (List.length chain_decls) - 1 - rev_idx in + let name_opt, _ = List.nth chain_paths_pre idx in + match name_opt with + | None -> + (* Chain entry reuses an existing instance — don't register + a duplicate. The pre-existing instance already provides + this ancestor's ops + axioms, and its path is what + [chain_paths_pre]/[chain_paths] return for [idx]. *) + scope + | Some name -> + let anc_symbols = + List.fold_left + (fun m (id, _) -> + let n = EcIdent.name id in + let local = lookup_ren ren n in + match Mstr.find_opt local symbols with + | Some s -> Mstr.add n s m + | None -> m) + Mstr.empty anc_decl.tc_ops in + (* Find this entry's parent chain entries: for each parent + of [anc] (in [anc_decl.tc_prts]), the parent chain entry + has the same TC and the renaming composed with [ren]. *) + let parents = + List.map + (fun (p_tc, p_ren) -> + let target_ren = compose_ren ~outer:p_ren ~inner:ren in + let rec find i = function + | [] -> None + | (a, _, r) :: rest -> + if EcPath.p_equal a.EcAst.tc_name p_tc.EcAst.tc_name + && ren_eq r target_ren + then chain_paths.(i) + else find (i + 1) rest + in find 0 chain_decls) + anc_decl.tc_prts in + let parents = List.pmap (fun x -> x) parents in + let instance = EcTheory. + { tci_params = fst ty + ; tci_type = snd ty + ; tci_instance = `General (anc, Some anc_symbols) + ; tci_local = lc + ; tci_parents = parents } in + let item = EcTheory.Th_instance (Some name, instance) in + let item = EcTheory.mkitem ~import item in + { scope with sc_env = EcSection.add_item item scope.sc_env }) + scope (List.rev chain_decls) in + + (* Auto-skip a chain entry's axioms if a previously-declared + instance for the same (TC name, carrier) already proves them. + Symbols-equivalent means: for every op declared in the + ancestor, both the existing instance and the chain entry's + expected symbol map agree on the underlying op-path. + This is what lets [instance addmonoid with int] (with just ops, + no proofs) succeed when [instance monoid with int] is already + discharged: addmonoid's monoid-axiom obligations are + discharged by the existing monoid instance. The chain entries + we register in this declaration are excluded by path. *) + (* Only freshly-registered paths count as "self": reused paths + refer to instances that pre-existed, and we want + [already_discharged] to count them as the discharger. *) + let chain_self_paths = + List.filter_map + (fun (name_opt, p) -> + if Option.is_some name_opt then Some p else None) + chain_paths_pre + |> EcPath.Sp.of_list in + let already_discharged (anc : typeclass) (anc_decl : tc_decl) (ren : _) : bool = + let expected = + List.fold_left + (fun m (id, _) -> + let n = EcIdent.name id in + let local = lookup_ren ren n in + match Mstr.find_opt local symbols with + | Some s -> Mstr.add n s m + | None -> m) + Mstr.empty anc_decl.tc_ops in + let same_symbols (existing_syms : (path * etyarg list) Mstr.t) = + Mstr.for_all + (fun n (p, _) -> + match Mstr.find_opt n existing_syms with + | Some (p', _) -> EcPath.p_equal p p' + | None -> false) + expected in + List.exists + (fun (path_opt, tci) -> + let is_other = + match path_opt with + | Some path -> not (EcPath.Sp.mem path chain_self_paths) + | None -> true in + is_other + && EcReduction.EqTest.for_type + (env scope) tci.EcTheory.tci_type (snd ty) + && (match tci.EcTheory.tci_instance with + | `General (anc', Some syms) -> + EcPath.p_equal anc'.tc_name anc.tc_name + && same_symbols syms + | _ -> false)) + (EcEnv.TcInstance.get_all (env scope)) in + + (* Build the proof-obligation list (deduped by axiom name across + chain entries) and check the user's tactics against it, now + that the chain instances are bound in the env so [tc_reduce] + can walk through their pre-computed paths. *) + let axioms = + let _, axs = + List.fold_left + (fun (seen, acc) (anc, anc_decl, ren) -> + if already_discharged anc anc_decl ren then (seen, acc) + else + List.fold_left + (fun (seen, acc) (name, ax) -> + if Sstr.mem name seen then (seen, acc) + else + (Sstr.add name seen, + (name, EcSubst.subst_form subst ax) :: acc)) + (seen, acc) + anc_decl.tc_axs) + (Sstr.empty, []) chain_decls in + List.rev axs in + let inter = check_tci_axioms ~tparams:(fst ty) scope mode tci.pti_axs axioms lc in + + Ax.add_defer scope inter + + (* ------------------------------------------------------------------ *) + let add_instance + ?(import = true) (scope : scope) mode ({ pl_desc = tci } as toptci) + = + match unloc (fst tci.pti_tc) with + | ([], "bring") -> begin + if EcUtils.is_some tci.pti_args then + hierror "unsupported-option"; + addring ~import scope mode (`Boolean, toptci) + end + + | ([], "ring") -> begin + let kind = + match tci.pti_args with + | None -> `Integer + | Some (`Ring (c, p)) -> + if odfl false (c |> omap (fun c -> c <^ BI.of_int 2)) then + hierror "invalid coefficient modulus"; + if odfl false (p |> omap (fun p -> p <^ BI.of_int 2)) then + hierror "invalid power modulus"; + if opt_equal BI.equal c (Some (BI.of_int 2)) + && opt_equal BI.equal p (Some (BI.of_int 2)) + then `Boolean + else `Modulus (c, p) + in addring ~import scope mode (kind, toptci) + end + + | ([], "field") -> addfield ~import scope mode toptci + + | _ -> + if EcUtils.is_some tci.pti_args then + hierror "unsupported-option"; + add_generic_instance ~import scope mode toptci +end + +(* -------------------------------------------------------------------- *) +module Theory = struct + open EcTheory + + exception TopScope + + (* ------------------------------------------------------------------ *) + let bind ?(import = true) (scope : scope) (cth : thloaded) = + assert (scope.sc_pr_uc = None); + { scope with + sc_env = EcSection.add_th ~import cth scope.sc_env } + + (* ------------------------------------------------------------------ *) + let required (scope : scope) (name : required_info) = + assert (scope.sc_pr_uc = None); + List.exists (fun x -> + if x.rqd_name = name.rqd_name then ( + (* FIXME: raise an error message *) + assert (x.rqd_digest = name.rqd_digest); + true) + else false) + scope.sc_required + + (* ------------------------------------------------------------------ *) + let mark_as_direct (scope : scope) (name : symbol) = + let for1 rq = + if rq.rqd_name = name + then { rq with rqd_direct = true } + else rq + in { scope with sc_required = List.map for1 scope.sc_required } + + (* ------------------------------------------------------------------ *) + let enter ?src:_ (scope : scope) (mode : thmode) (name : symbol) = + assert (scope.sc_pr_uc = None); + subscope scope mode name + + (* ------------------------------------------------------------------ *) + let rec require_loaded (id : required_info) scope = + if required scope id then + scope + else + match Msym.find_opt id.rqd_name scope.sc_loaded with + | Some (rth, ids) -> + let scope = List.fold_right require_loaded ids scope in + let env = EcSection.require rth scope.sc_env in + { scope with + sc_env = env; + sc_required = id :: scope.sc_required; } + + | None -> assert false + + (* ------------------------------------------------------------------ *) + let update_with_required ~(dst : scope) ~(src : scope) = + let dst = + let sc_loaded = + Msym.union + (fun _ x y -> assert (x ==(*phy*) y); Some x) + dst.sc_loaded src.sc_loaded + in { dst with sc_loaded } + in List.fold_right require_loaded src.sc_required dst + + (* ------------------------------------------------------------------ *) + let add_clears clears scope = + let clears = + let for1 = function + | None -> EcEnv.root (env scope) + | Some { pl_loc = loc; pl_desc = (xs, x) as q } -> + let xp = EcEnv.root (env scope) in + let xp = EcPath.pqname (EcPath.extend xp xs) x in + if is_none (EcEnv.Theory.by_path_opt xp (env scope)) then + hierror ~loc "unknown theory: `%s`" (string_of_qsymbol q); + xp + in List.map for1 clears + in { scope with sc_clears = scope.sc_clears @ clears } + + (* -------------------------------------------------------------------- *) + let exit_r ?pempty (scope : scope) = + match scope.sc_top with + | None -> raise TopScope + | Some sup -> + let clears = scope.sc_clears in + let _, cth, _ = EcSection.exit_theory ?pempty ~clears scope.sc_env in + let loaded = scope.sc_loaded in + let required = scope.sc_required in + let sup = { sup with sc_loaded = loaded; } in + ((cth, required), scope.sc_name, sup) + + (* ------------------------------------------------------------------ *) + let exit ?(import = true) ?(pempty = `ClearOnly) ?(clears =[]) (scope : scope) = + assert (scope.sc_pr_uc = None); + + let cth = exit_r ~pempty (add_clears clears scope) in + let ((cth, required), (name, _), scope) = cth in + let scope = List.fold_right require_loaded required scope in + let scope = ofold (fun cth scope -> bind ~import scope cth) scope cth in + (name, scope) + + (* ------------------------------------------------------------------ *) + let bump_prelude (scope : scope) = + match scope.sc_prelude with + | `InPrelude, _ -> + { scope with sc_prelude = (`InPrelude, + { pr_env = env scope; + pr_required = scope.sc_required; }) } + | _ -> scope + + (* ------------------------------------------------------------------ *) + let import (scope : scope) (name : qsymbol) = + assert (scope.sc_pr_uc = None); + + match EcEnv.Theory.lookup_opt ~mode:`All name (env scope) with | None -> hierror "cannot import the non-existent theory `%s'" @@ -2012,8 +2792,7 @@ module Theory = struct "end-of-file while processing proof %s" (fst scope.sc_name) (* -------------------------------------------------------------------- *) - let require_start (scope : scope) (thname : symbol) (mode : thmode) - : scope = + let require_start (scope : scope) (thname : symbol) (mode : thmode) : scope = let new_ = enter (for_loading scope) mode thname `Global in { new_ with sc_env = EcSection.astop new_.sc_env } @@ -2036,6 +2815,7 @@ module Theory = struct end else match Msym.find_opt name.rqd_name scope.sc_loaded with | Some _ -> require_loaded name scope + | None -> try let imported = require_start scope name.rqd_name mode in @@ -2051,7 +2831,6 @@ module Theory = struct raise (ImportError (None, name.rqd_name, e)) end - (* -------------------------------------------------------------------- *) let required scope = scope.sc_required (* -------------------------------------------------------------------- *) @@ -2060,7 +2839,7 @@ module Theory = struct let thpath, _ = ofdfl (fun () -> hierror ~loc:(loc target) "unknown theory: %a" pp_qsymbol (unloc target) ) thpath in - let item = EcTheory.mkitem ~import:true (Th_alias (unloc name, thpath)) in + let item = EcTheory.mkitem ~import:true (EcTheory.Th_alias (unloc name, thpath)) in { scope with sc_env = EcSection.add_item item scope.sc_env } end @@ -2070,599 +2849,151 @@ module Section = struct module T = EcTheory let enter (scope : scope) (name : psymbol option) = - assert (scope.sc_pr_uc = None); - { scope with - sc_env = EcSection.enter_section (omap unloc name) scope.sc_env } - - let exit (scope : scope) (name : psymbol option) = - let sc_env = EcSection.exit_section (omap unloc name) scope.sc_env in - { scope with sc_env } -end - -(* -------------------------------------------------------------------- *) -module Reduction = struct - (* FIXME: section -> allow "local" flag *) - let add_reduction scope (opts, reds) = - check_state `InTop "hint simplify" scope; - - let rules = - let for1 idx name = - let idx = odfl 0 idx in - let ax_p = EcEnv.Ax.lookup_path (unloc name) (env scope) in - let opts = EcTheory.{ - ur_delta = List.mem `Delta opts; - ur_eqtrue = List.mem `EqTrue opts; - } in - - let red_info = - EcReduction.User.compile ~opts ~prio:idx (env scope) ax_p in - (ax_p, opts, Some red_info) in - - let rules = List.map (fun (xs, idx) -> List.map (for1 idx) xs) reds in - List.flatten rules - - in - - let item = EcTheory.mkitem ~import:true (EcTheory.Th_reduction rules) in - { scope with sc_env = EcSection.add_item item scope.sc_env } -end - -(* -------------------------------------------------------------------- *) -module Cloning = struct - (* ------------------------------------------------------------------ *) - open EcTheory - open EcThCloning - - module C = EcThCloning - module R = EcTheoryReplay - - (* ------------------------------------------------------------------ *) - let hooks ~(override_locality: is_local option) : scope R.ovrhooks = - let thexit sc ~import pempty = - snd (Theory.exit ~import ?clears:None ~pempty sc) in - let add_item scope ~import item = - let item = EcTheory.mkitem ~import item in - { scope with sc_env = EcSection.add_item ~override_locality item scope.sc_env } in - { R.henv = (fun scope -> scope.sc_env); - R.hadd_item = add_item; - R.hthenter = Theory.enter; - R.hthexit = thexit; - R.herr = (fun ?loc -> hierror ?loc "%s"); } - - (* ------------------------------------------------------------------ *) - module Options = struct - open EcTheoryReplay - - let default = { clo_abstract = false; } - - let merge1 opts (b, (x : theory_cloning_option)) = - match x with - | `Abstract -> { opts with clo_abstract = b; } - - let merge opts (specs : theory_cloning_options) = - List.fold_left merge1 opts specs - end - - (* ------------------------------------------------------------------ *) - let replay_proofs (scope : scope) (mode : Tactics.proofmode) (proofs : _) = - proofs |> List.pmap (fun axc -> - match axc.C.axc_tac with - | None -> - Some (fst_map some axc.C.axc_axiom, axc.C.axc_path, axc.C.axc_env) - - | Some pt -> - let t = { pt_core = pt; pt_intros = []; } in - let t = { pl_loc = pt.pl_loc; pl_desc = Pby (Some [t]); } in - let t = { pt_core = t; pt_intros = []; } in - let (x, ax) = axc.C.axc_axiom in - - let pucflags = { puc_smt = true; puc_local = false; } in - let pucflags = (([], None), pucflags) in - let check = Check_mode.check scope.sc_options in - - let escope = { scope with sc_env = axc.C.axc_env; } in - let escope = Ax.start_lemma escope pucflags check ~name:x (ax, None) in - let escope = Tactics.proof escope in - let escope = snd (Tactics.process_r ~reloc:x false mode escope [t]) in - ignore (Ax.save_r escope); None - ) - - (* ------------------------------------------------------------------ *) - let clone (scope : scope) mode (thcl : theory_cloning) = - assert (scope.sc_pr_uc = None); - - let { cl_name = name; - cl_theory = (opath, oth); - cl_clone = ovrds; - cl_rename = rnms; - cl_ntclr = ntclr; } - - = C.clone scope.sc_env thcl in - - let incl = thcl.pthc_import = Some `Include in - let opts = Options.merge Options.default thcl.pthc_opts in - - if thcl.pthc_import = Some `Include && opts.R.clo_abstract then - hierror "cannot include an abstract theory"; - if thcl.pthc_import = Some `Include && EcUtils.is_some thcl.pthc_name then - hierror "cannot give an alias to an included clone"; - - let cpath = EcEnv.root (env scope) in - let npath = if incl then cpath else EcPath.pqname cpath name in - - let (proofs, scope) = - EcTheoryReplay.replay (hooks ~override_locality:thcl.pthc_local) - ~abstract:opts.R.clo_abstract ~override_locality:thcl.pthc_local ~incl - ~clears:ntclr ~renames:rnms ~opath ~npath ovrds - scope (name, false, oth.cth_items, oth.cth_loca) - in - - let proofs = replay_proofs scope mode proofs in - - let scope = - thcl.pthc_import |> ofold (fun flag scope -> - match flag with - | `Import -> - { scope with sc_env = EcSection.import npath scope.sc_env; } - | `Export -> - let item = EcTheory.mkitem ~import:true (Th_export (npath, `Global)) in - { scope with sc_env = EcSection.add_item item scope.sc_env; } - | `Include -> scope) - scope - in - - if is_none thcl.pthc_local && oth.cth_loca = `Local then - notify scope `Info - "Theory `%s` has inherited `local` visibility. \ - Use the `global` keyword if this is not wanted." - name; - - Ax.add_defer scope proofs - -end - -(* -------------------------------------------------------------------- *) -module Ty = struct - open EcDecl - open EcTyping - - module TT = EcTyping - module ELI = EcInductive - module EHI = EcHiInductive - - (* ------------------------------------------------------------------ *) - let check_name_available scope x = - let pname = EcPath.pqname (EcEnv.root (env scope)) x.pl_desc in - - if EcEnv.Ty .by_path_opt pname (env scope) <> None - || EcEnv.TypeClass.by_path_opt pname (env scope) <> None then - hierror ~loc:x.pl_loc "duplicated type/type-class name `%s'" x.pl_desc - - (* ------------------------------------------------------------------ *) - let bind ?(import = true) (scope : scope) ((x, tydecl) : (_ * tydecl)) = - assert (scope.sc_pr_uc = None); - let item = EcTheory.mkitem ~import (EcTheory.Th_type (x, tydecl)) in - { scope with - sc_env = EcSection.add_item item scope.sc_env; - sc_locdoc = DocState.add_item scope.sc_locdoc; } - - (* ------------------------------------------------------------------ *) - let add ?(src : string option) scope (tyd : ptydecl located) = - let utyd = unloc tyd in - let scope = - { scope with - sc_locdoc = - match utyd.pty_locality with - | `Local -> DocState.prevent_process scope.sc_locdoc - | `Global -> DocState.start_process scope.sc_locdoc (unloc utyd.pty_name) `Type `Specific - | `Declare -> DocState.start_process scope.sc_locdoc (unloc utyd.pty_name) `Type `Abstract } - in - let scope = - { scope with - sc_locdoc = - match src with - | Some src -> DocState.push_srcbl scope.sc_locdoc src - | None -> scope.sc_locdoc; } - in - - let loc = loc tyd in - - let { pty_name = name; pty_tyvars = args; - pty_body = body; pty_locality = tyd_loca } = unloc tyd in - - check_name_available scope name; - let env = env scope in - let tyd_params, tyd_type = - match body with - | PTYD_Abstract -> - let ue = TT.transtyvars env (loc, Some args) in - EcUnify.UniEnv.tparams ue, Abstract - - | PTYD_Alias bd -> - let ue = TT.transtyvars env (loc, Some args) in - let body = transty tp_tydecl env ue bd in - EcUnify.UniEnv.tparams ue, Concrete body - - | PTYD_Datatype dt -> ( - let datatype = EHI.trans_datatype env (mk_loc loc (args, name)) dt in - let ty_from_ctor ctor = EcEnv.Ty.by_path ctor env in - try - ELI.check_positivity ty_from_ctor datatype; - let tparams, tydt = ELI.datatype_as_ty_dtype datatype in - (tparams, Datatype tydt) - with ELI.NonPositive ctx -> - let symbol = basename datatype.dt_path in - EHI.dterror loc env (EHI.DTE_NonPositive (symbol, ctx))) - - | PTYD_Record rt -> - let record = EHI.trans_record env (mk_loc loc (args,name)) rt in - let scheme = ELI.indsc_of_record record in - record.ELI.rc_tparams, Record (scheme, record.ELI.rc_fields) - in - - bind scope (unloc name, { tyd_params; tyd_type; tyd_loca; }) - - (* ------------------------------------------------------------------ *) - let add_subtype (scope : scope) ({ pl_desc = subtype } : psubtype located) = - let loced x = mk_loc _dummy x in - let env = env scope in - - let scope = - let decl = EcDecl.{ - tyd_params = []; - tyd_type = Abstract; - tyd_loca = `Global; (* FIXME:SUBTYPE *) - } in bind scope (unloc subtype.pst_name, decl) in - - let carrier = - let ue = EcUnify.UniEnv.create None in - transty tp_tydecl env ue subtype.pst_carrier in - - let pred = - let x = EcIdent.create (fst subtype.pst_pred).pl_desc in - let env = EcEnv.Var.bind_local x carrier env in - let ue = EcUnify.UniEnv.create None in - let pred = EcTyping.trans_prop env ue (snd subtype.pst_pred) in - if not (EcUnify.UniEnv.closed ue) then - hierror ~loc:(snd subtype.pst_pred).pl_loc - "the predicate contains free type variables"; - let uidmap = EcUnify.UniEnv.close ue in - let fs = Tuni.subst uidmap in - f_lambda [(x, GTty carrier)] (Fsubst.f_subst fs pred) in - - let evclone = - { EcThCloning.evc_empty with - evc_types = Msym.of_list [ - "T", loced (`Direct carrier, `Inline `Clear); - "sT", loced ( - `ByPath (EcPath.pqname (EcEnv.root env) (unloc subtype.pst_name)), - `Inline `Clear - ); - ]; - evc_ops = Msym.of_list [ - "P", loced (`Direct pred, `Inline `Clear) - ]; - evc_lemmas = { - ev_bynames = Msym.empty; - ev_global = [ (None, Some [`Include, "prove"]) ] - } } in - - let cname = Option.map unloc subtype.pst_cname in - let npath = ofold ((^~) EcPath.pqname) (EcEnv.root env) cname in - let cpath = EcPath.fromqsymbol ([EcCoreLib.i_top], "Subtype") in - let theory = EcEnv.Theory.by_path ~mode:`Abstract cpath env in - - let renames = - match subtype.pst_rename with - | None -> [] - - | Some (insub, val_) -> [ - (`All, (EcRegexp.regexp "val", EcRegexp.subst val_)); - (`All, (EcRegexp.regexp "insub", EcRegexp.subst insub)); - ] in - - let (proofs, scope) = - EcTheoryReplay.replay (Cloning.hooks ~override_locality:None) - ~abstract:false ~override_locality:None ~incl:(Option.is_none cname) - ~clears:Sp.empty ~renames ~opath:cpath ~npath - evclone scope - ( - Option.value ~default:(EcPath.basename cpath) cname, - false, - theory.cth_items, - theory.cth_loca - ) in - - let proofs = Cloning.replay_proofs scope `Check proofs in - - Ax.add_defer scope proofs - - (* ------------------------------------------------------------------ *) - let check_tci_operators env tcty ops reqs = - let ue = EcUnify.UniEnv.create (Some (fst tcty)) in - let rmap = Mstr.of_list reqs in + assert (scope.sc_pr_uc = None); + { scope with + sc_env = EcSection.enter_section (omap unloc name) scope.sc_env } - let ops = - let tt1 m (x, (tvi, op)) = - if not (Mstr.mem (unloc x) rmap) then - hierror ~loc:x.pl_loc "invalid operator name: `%s'" (unloc x); + let exit (scope : scope) (name : psymbol option) = + let sc_env = EcSection.exit_section (omap unloc name) scope.sc_env in + { scope with sc_env } +end - let tvi = List.map (TT.transty tp_tydecl env ue) tvi in - let selected = - EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper) - (Some (EcUnify.TVIunamed tvi)) env (unloc op) ue ([], None) - in - let op = - match selected with - | [] -> hierror ~loc:op.pl_loc "unknown operator" - | op1::op2::_ -> - hierror ~loc:op.pl_loc - "ambiguous operator (%s / %s)" - (EcPath.tostring (fst (proj4_1 op1))) - (EcPath.tostring (fst (proj4_1 op2))) - | [((p, _), _, _, _)] -> - let op = EcEnv.Op.by_path p env in - let opty = - Tvar.subst - (Tvar.init op.op_tparams tvi) - op.op_ty - in - (p, opty) +(* -------------------------------------------------------------------- *) +module Reduction = struct + (* FIXME: section -> allow "local" flag *) + let add_reduction scope (opts, reds) = + check_state `InTop "hint simplify" scope; - in - Mstr.change - (function - | None -> Some (x.pl_loc, op) - | Some _ -> hierror ~loc:(x.pl_loc) - "duplicated operator name: `%s'" (unloc x)) - (unloc x) m - in - List.fold_left tt1 Mstr.empty ops - in - List.iter - (fun (x, (req, _)) -> - if req && not (Mstr.mem x ops) then - hierror "no definition for operator `%s'" x) - reqs; - List.fold_left - (fun m (x, (_, ty)) -> - match Mstr.find_opt x ops with - | None -> m - | Some (loc, (p, opty)) -> - if not (EcReduction.EqTest.for_type env ty opty) then - hierror ~loc "invalid type for operator `%s'" x; - Mstr.add x p m) - Mstr.empty reqs + let rules = + let for1 idx name = + let idx = odfl 0 idx in + let ax_p = EcEnv.Ax.lookup_path (unloc name) (env scope) in + let opts = EcTheory.{ + ur_delta = List.mem `Delta opts; + ur_eqtrue = List.mem `EqTrue opts; + } in - (* ------------------------------------------------------------------ *) - let check_tci_axioms scope mode axs reqs lc = - let rmap = Mstr.of_list reqs in - let symbs, axs = - List.map_fold - (fun m (x, t) -> - if not (Mstr.mem (unloc x) rmap) then - hierror ~loc:x.pl_loc "invalid axiom name: `%s'" (unloc x); - if Sstr.mem (unloc x) m then - hierror ~loc:(x.pl_loc) "duplicated axiom name: `%s'" (unloc x); - (Sstr.add (unloc x) m, (unloc x, t, Mstr.find (unloc x) rmap))) - Sstr.empty axs in + let red_info = + EcReduction.User.compile ~opts ~prio:idx (env scope) ax_p in + (ax_p, opts, Some red_info) in - let interactive = - List.pmap - (fun (x, req) -> - if not (Mstr.mem x symbs) then - let ax = { - ax_tparams = []; - ax_spec = req; - ax_kind = `Lemma; - ax_loca = lc; - ax_smt = false; - } in Some ((None, ax), EcPath.psymbol x, scope.sc_env) - else None) - reqs in - List.iter - (fun (x, pt, f) -> - let x = "$" ^ x in - let t = { pt_core = pt; pt_intros = []; } in - let t = { pl_loc = pt.pl_loc; pl_desc = Pby (Some [t]) } in - let t = { pt_core = t; pt_intros = []; } in - let ax = { - ax_tparams = []; - ax_spec = f; - ax_kind = `Lemma; - ax_smt = false; - ax_loca = lc; - } in + let rules = List.map (fun (xs, idx) -> List.map (for1 idx) xs) reds in + List.flatten rules - let pucflags = { puc_smt = true; puc_local = false; } in - let pucflags = (([], None), pucflags) in - let check = Check_mode.check scope.sc_options in + in - let escope = scope in - let escope = Ax.start_lemma escope pucflags check ~name:x (ax, None) in - let escope = Tactics.proof escope in - let escope = snd (Tactics.process_r ~reloc:x false mode escope [t]) in - ignore (Ax.save_r escope)) - axs; - interactive + let item = EcTheory.mkitem ~import:true (EcTheory.Th_reduction rules) in + { scope with sc_env = EcSection.add_item item scope.sc_env } +end +(* -------------------------------------------------------------------- *) +module Cloning = struct (* ------------------------------------------------------------------ *) - (* FIXME section: those path does not exists ... - futhermode Ring.ZModule is an abstract theory *) - let p_zmod = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "ZModule"], "zmodule") - let p_ring = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "ComRing"], "ring" ) - let p_idomain = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "IDomain"], "idomain") - let p_field = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "Field" ], "field" ) + open EcTheory + open EcThCloning + + module C = EcThCloning + module R = EcTheoryReplay (* ------------------------------------------------------------------ *) - let ring_of_symmap env ty kind symbols = - { r_type = ty; - r_zero = oget (Mstr.find_opt "rzero" symbols); - r_one = oget (Mstr.find_opt "rone" symbols); - r_add = oget (Mstr.find_opt "add" symbols); - r_opp = (Mstr.find_opt "opp" symbols); - r_mul = oget (Mstr.find_opt "mul" symbols); - r_exp = (Mstr.find_opt "expr" symbols); - r_sub = (Mstr.find_opt "sub" symbols); - r_kind = kind; - r_embed = - (match Mstr.find_opt "ofint" symbols with - | None when EcReduction.EqTest.for_type env ty tint -> `Direct - | None -> `Default | Some p -> `Embed p); } + let hooks ~(override_locality : is_local option) : scope R.ovrhooks = + let thexit sc ~import pempty = snd (Theory.exit ~import ?clears:None ~pempty sc) in + let add_item scope ~import item = + let item = EcTheory.mkitem ~import item in + { scope with + sc_env = EcSection.add_item ~override_locality item scope.sc_env } in + { R.henv = (fun scope -> scope.sc_env); + R.hadd_item = add_item; + R.hthenter = Theory.enter; + R.hthexit = thexit; + R.herr = (fun ?loc -> hierror ?loc "%s"); } - let addring ~import (scope : scope) mode (kind, { pl_desc = tci; pl_loc = loc }) = - let env = env scope in - if not (EcAlgTactic.is_module_loaded env) then - hierror "load AlgTactic/Ring first"; + let () = subtype_hooks_ref := hooks ~override_locality:None - let ty = - let ue = TT.transtyvars env (loc, Some (fst tci.pti_type)) in - let ty = transty tp_tydecl env ue (snd tci.pti_type) in - assert (EcUnify.UniEnv.closed ue); - let uidmap = EcUnify.UniEnv.close ue in - (EcUnify.UniEnv.tparams ue, ty_subst (Tuni.subst uidmap) ty) - in - if not (List.is_empty (fst ty)) then - hierror "ring instances cannot be polymorphic"; + (* ------------------------------------------------------------------ *) + module Options = struct + open EcTheoryReplay - let symbols = EcAlgTactic.ring_symbols env kind (snd ty) in - let symbols = check_tci_operators env ty tci.pti_ops symbols in - let cr = ring_of_symmap env (snd ty) kind symbols in - let axioms = EcAlgTactic.ring_axioms env cr in - let lc = (tci.pti_loca :> locality) in - let inter = check_tci_axioms scope mode tci.pti_axs axioms lc in - let add env p = - let item = EcTheory.Th_instance (ty,`General p, tci.pti_loca) in - let item = EcTheory.mkitem ~import item in - EcSection.add_item item env in + let default = { clo_abstract = false; } - let scope = - { scope with sc_env = - List.fold_left add - (let item = - EcTheory.Th_instance (([], snd ty), `Ring cr, tci.pti_loca) in - let item = EcTheory.mkitem ~import item in - EcSection.add_item item scope.sc_env) - [p_zmod; p_ring; p_idomain] } + let merge1 opts (b, (x : theory_cloning_option)) = + match x with + | `Abstract -> { opts with clo_abstract = b; } - in Ax.add_defer scope inter + let merge opts (specs : theory_cloning_options) = + List.fold_left merge1 opts specs + end (* ------------------------------------------------------------------ *) - let field_of_symmap env ty symbols = - { f_ring = ring_of_symmap env ty `Integer symbols; - f_inv = oget (Mstr.find_opt "inv" symbols); - f_div = Mstr.find_opt "div" symbols; } + let clone (scope : scope) mode (thcl : theory_cloning) = + assert (scope.sc_pr_uc = None); - let addfield ~import (scope : scope) mode { pl_desc = tci; pl_loc = loc; } = - let env = env scope in - if not (EcAlgTactic.is_module_loaded env) then - hierror "load AlgTactic/Ring first"; + let { cl_name = name; + cl_theory = (opath, oth); + cl_clone = ovrds; + cl_rename = rnms; + cl_ntclr = ntclr; } - let ty = - let ue = TT.transtyvars env (loc, Some (fst tci.pti_type)) in - let ty = transty tp_tydecl env ue (snd tci.pti_type) in - assert (EcUnify.UniEnv.closed ue); - let uidmap = EcUnify.UniEnv.close ue in - (EcUnify.UniEnv.tparams ue, ty_subst (Tuni.subst uidmap) ty) - in - if not (List.is_empty (fst ty)) then - hierror "field instances cannot be polymorphic"; - let symbols = EcAlgTactic.field_symbols env (snd ty) in - let symbols = check_tci_operators env ty tci.pti_ops symbols in - let cr = field_of_symmap env (snd ty) symbols in - let axioms = EcAlgTactic.field_axioms env cr in - let lc = (tci.pti_loca :> locality) in - let inter = check_tci_axioms scope mode tci.pti_axs axioms lc; in - let add env p = - let item = EcTheory.Th_instance(ty,`General p, tci.pti_loca) in - let item = EcTheory.mkitem ~import item in - EcSection.add_item item env in - let scope = - { scope with - sc_env = - List.fold_left add - (let item = - EcTheory.Th_instance (([], snd ty), `Field cr, tci.pti_loca) in - let item = EcTheory.mkitem ~import item in - EcSection.add_item item scope.sc_env) - [p_zmod; p_ring; p_idomain; p_field] } + = C.clone scope.sc_env thcl in - in Ax.add_defer scope inter + let incl = thcl.pthc_import = Some `Include in + let opts = Options.merge Options.default thcl.pthc_opts in - (* ------------------------------------------------------------------ *) - let symbols_of_tc (_env : EcEnv.env) ty (tcp, tc) = - let subst = EcSubst.add_tydef EcSubst.empty tcp ([], ty) in - List.map (fun (x, opty) -> - (EcIdent.name x, (true, EcSubst.subst_ty subst opty))) - tc.tc_ops + if thcl.pthc_import = Some `Include && opts.R.clo_abstract then + hierror "cannot include an abstract theory"; + if thcl.pthc_import = Some `Include && EcUtils.is_some thcl.pthc_name then + hierror "cannot give an alias to an included clone"; -(* - (* ------------------------------------------------------------------ *) - let add_generic_tc (scope : scope) _mode { pl_desc = tci; pl_loc = loc; } = - let ty = - let ue = TT.transtyvars scope.sc_env (loc, Some (fst tci.pti_type)) in - let ty = transty tp_tydecl scope.sc_env ue (snd tci.pti_type) in - assert (EcUnify.UniEnv.closed ue); - (EcUnify.UniEnv.tparams ue, Tuni.offun (EcUnify.UniEnv.close ue) ty) - in + let cpath = EcEnv.root (env scope) in + let npath = if incl then cpath else EcPath.pqname cpath name in - let (tcp, tc) = - match EcEnv.TypeClass.lookup_opt (unloc tci.pti_name) (env scope) with - | None -> - hierror ~loc:tci.pti_name.pl_loc - "unknown type-class: %s" (string_of_qsymbol (unloc tci.pti_name)) - | Some tc -> tc + let (proofs, scope) = + EcTheoryReplay.replay (hooks ~override_locality:thcl.pthc_local) + ~abstract:opts.R.clo_abstract ~override_locality:thcl.pthc_local ~incl + ~clears:ntclr ~renames:rnms ~opath ~npath ovrds + scope (name, oth.cth_items, oth.cth_loca) in - let symbols = symbols_of_tc scope.sc_env (snd ty) (tcp, tc) in - let _symbols = check_tci_operators scope.sc_env ty tci.pti_ops symbols in + let proofs = List.pmap (fun axc -> + match axc.C.axc_tac with + | None -> + Some (fst_map some axc.C.axc_axiom, axc.C.axc_path, axc.C.axc_env) - { scope with - sc_env = EcEnv.TypeClass.add_instance ty (`General tcp) scope.sc_env } + | Some pt -> + let t = { pt_core = pt; pt_intros = []; } in + let t = { pl_loc = pt.pl_loc; pl_desc = Pby (Some [t]); } in + let t = { pt_core = t; pt_intros = []; } in + let (x, ax) = axc.C.axc_axiom in -(* - let ue = EcUnify.UniEnv.create (Some []) in - let ty = fst (EcUnify.UniEnv.openty ue (fst ty) None (snd ty)) in - try EcUnify.hastc scope.sc_env ue ty (Sp.singleton (fst tc)); tc - with EcUnify.UnificationFailure _ -> - hierror "type must be an instance of `%s'" (EcPath.tostring (fst tc)) -*) -*) + let pucflags = { puc_smt = true; puc_local = false; } in + let pucflags = (([], None), pucflags) in + let check = Check_mode.check scope.sc_options in - (* ------------------------------------------------------------------ *) - let add_instance (scope : scope) mode ({ pl_desc = tci } as toptci) = - match unloc tci.pti_name with - | ([], "bring") -> begin - if EcUtils.is_some tci.pti_args then - hierror "unsupported-option"; - addring ~import:true scope mode (`Boolean, toptci) - end + let escope = { scope with sc_env = axc.C.axc_env; } in + let escope = Ax.start_lemma escope pucflags check ~name:x (ax, None) in + let escope = Tactics.proof escope in + let escope = snd (Tactics.process_r ~reloc:x false mode escope [t]) in + ignore (Ax.save_r escope); None) + proofs + in - | ([], "ring") -> begin - let kind = - match tci.pti_args with - | None -> `Integer - | Some (`Ring (c, p)) -> - if odfl false (c |> omap (fun c -> c <^ BI.of_int 2)) then - hierror "invalid coefficient modulus"; - if odfl false (p |> omap (fun p -> p <^ BI.of_int 2)) then - hierror "invalid power modulus"; - if opt_equal BI.equal c (Some (BI.of_int 2)) - && opt_equal BI.equal p (Some (BI.of_int 2)) - then `Boolean - else `Modulus (c, p) - in addring ~import:true scope mode (kind, toptci) - end + let scope = + thcl.pthc_import |> ofold (fun flag scope -> + match flag with + | `Import -> + { scope with sc_env = EcSection.import npath scope.sc_env; } + | `Export -> + let item = EcTheory.mkitem ~import:true (Th_export (npath, `Global)) in + { scope with sc_env = EcSection.add_item item scope.sc_env; } + | `Include -> scope) + scope - | ([], "field") -> addfield ~import:true scope mode toptci + in Ax.add_defer scope proofs - | _ -> - if EcUtils.is_some tci.pti_args then - hierror "unsupported-option"; - failwith "unsupported" (* FIXME *) end -(* -------------------------------------------------------------------- *)module Search = struct +(* -------------------------------------------------------------------- *) +module Search = struct let search (scope : scope) qs = let env = env scope in let paths = @@ -2680,15 +3011,15 @@ end let ps = ref Mid.empty in let ue = EcUnify.UniEnv.create None in let tip = EcUnify.UniEnv.opentvi ue decl.op_tparams None in - let tip = f_subst_init ~tv:tip () in - let es = e_subst tip in + let tip = f_subst_init ~tv:(Mid.map fst tip.subst) () in + let es = e_subst tip in let xs = List.map (snd_map (ty_subst tip)) nt.ont_args in - let bd = EcFol.form_of_expr (es nt.ont_body) in - List.iter (fun (id, ty) -> ps := Mid.add id ty !ps) xs; + let bd = EcFol.form_of_expr ~m:(EcIdent.create "&hr") (es nt.ont_body) in + let fp = EcFol.f_lambda (List.map (snd_map EcFol.gtty) xs) bd in - match bd.f_node with + match fp.f_node with | Fop (pf, _) -> (pf :: paths, pts) - | _ -> (paths, (ps, ue, bd) ::pts) + | _ -> (paths, (ps, ue, fp) :: pts) end | _ -> (p :: paths, pts) in @@ -2724,14 +3055,29 @@ end notify scope `Info "%s" (Buffer.contents buffer) let locate (scope : scope) ({ pl_desc = name } : pqsymbol) = - let ppe = EcPrinting.PPEnv.ofenv (env scope) in - let shorten lk p = - let lk (p : path) (qs : qsymbol) = - match lk qs (env scope) with - | Some (p', _) -> p_equal p p' - | _ -> false in - EcPrinting.shorten_path ppe lk p + let rec doit prefix (nm, x) = + match lk (nm, x) (env scope) with + | Some (p', _) when EcPath.p_equal p p' -> + (nm, x) + | _ -> begin + match prefix with + | [] -> (nm, x) + | n :: prefix -> doit prefix (n :: nm, x) + end + in + + let (nm, x) = EcPath.toqsymbol p in + let nm = + match nm with + | top :: nm when top = EcCoreLib.i_top -> + nm + | _ -> nm in + + let nm', x' = doit (List.rev nm) ([], x) in + let plong, pshort = (nm, x), (nm', x') in + + (plong, if plong = pshort then None else Some pshort) in let buffer = Buffer.create 0 in @@ -2771,13 +3117,9 @@ end notify scope `Info "%s" (Buffer.contents buffer) end + (* -------------------------------------------------------------------- *) module DocComment = struct - let add (scope : scope) ((kind, docc) : [`Global | `Item] * string) : scope = - match kind with - | `Global -> - { scope with sc_globdoc = scope.sc_globdoc @ [docc] } - - | `Item -> - { scope with sc_locdoc = DocState.push_docbl scope.sc_locdoc docc } + let add (scope : scope) ((_, _) : [`Global | `Item] * string) : scope = + scope end diff --git a/src/ecScope.mli b/src/ecScope.mli index f3548a3bd1..e11cc21ff0 100644 --- a/src/ecScope.mli +++ b/src/ecScope.mli @@ -138,9 +138,9 @@ end (* -------------------------------------------------------------------- *) module Ty : sig val add : ?src:string -> scope -> ptydecl located -> scope - - val add_subtype : scope -> psubtype located -> scope - val add_instance : scope -> Ax.proofmode -> ptycinstance located -> scope + val add_class : scope -> ptypeclass located -> scope + val add_subtype : scope -> psubtype located -> scope + val add_instance : ?import:bool -> scope -> Ax.proofmode -> ptycinstance located -> scope end (* -------------------------------------------------------------------- *) diff --git a/src/ecSection.ml b/src/ecSection.ml index 828baaf9e6..0e1fe748cf 100644 --- a/src/ecSection.ml +++ b/src/ecSection.ml @@ -1,6 +1,7 @@ (* -------------------------------------------------------------------- *) open EcUtils open EcSymbols +open EcMaps open EcPath open EcAst open EcTypes @@ -21,7 +22,7 @@ type cbarg = [ | `Module of mpath | `ModuleType of path | `Typeclass of path - | `Instance of tcinstance + | `TcInstance of [`General of path | `Ring | `Field] ] type cb = cbarg -> unit @@ -37,35 +38,27 @@ let pp_cbarg env fmt (who : cbarg) = let ppe = EcPrinting.PPEnv.ofenv env in match who with | `Type p -> Format.fprintf fmt "type %a" (EcPrinting.pp_tyname ppe) p - | `Op p -> - begin - let op = EcEnv.Op.by_path p env in - match op.op_kind with - | OB_oper (Some (OP_Exn _)) -> - Format.fprintf fmt "exception %a" (EcPrinting.pp_opname ppe) p - | _ -> Format.fprintf fmt "operator %a" (EcPrinting.pp_opname ppe) p - end + | `Op p -> Format.fprintf fmt "operator %a" (EcPrinting.pp_opname ppe) p | `Ax p -> Format.fprintf fmt "lemma/axiom %a" (EcPrinting.pp_axname ppe) p | `Module mp -> let ppe = match mp.m_top with - | `Local id -> - if EcEnv.Mod.is_declared id env then - ppe - else EcPrinting.PPEnv.add_locals ppe [id] + | `Local id -> EcPrinting.PPEnv.add_locals ppe [id] | _ -> ppe in Format.fprintf fmt "module %a" (EcPrinting.pp_topmod ppe) mp | `ModuleType p -> - let mty = EcEnv.ModTy.modtype p env in - Format.fprintf fmt "module type %a" (EcPrinting.pp_modtype1 ppe) mty + Format.fprintf fmt "module type %a" + (EcPrinting.pp_modtype1 ppe) + (EcEnv.ModTy.modtype p env) | `Typeclass p -> - Format.fprintf fmt "typeclass %a" (EcPrinting.pp_tcname ppe) p - | `Instance tci -> - match tci with - | `Ring _ -> Format.fprintf fmt "ring instance" - | `Field _ -> Format.fprintf fmt "field instance" - | `General _ -> Format.fprintf fmt "instance" - + Format.fprintf fmt "typeclass %a" (EcPrinting.pp_tyname ppe) p + | `TcInstance (`General p) -> + Format.fprintf fmt "typeclass instance %a" (EcPrinting.pp_axname ppe) p + | `TcInstance `Ring -> + Format.fprintf fmt "ring instance" + | `TcInstance `Field -> + Format.fprintf fmt "field instance" + let pp_locality fmt = function | `Local -> Format.fprintf fmt "local" | `Global -> () @@ -101,393 +94,350 @@ let hierror fmt = bfmt fmt (* -------------------------------------------------------------------- *) -type aenv = { - env : EcEnv.env; (* Global environment for dep. analysis *) - cb : cb; (* Dep. analysis callback *) - cache : acache ref; (* Dep. analysis cache *) -} - -and acache = { - op : Sp.t; (* Operator declaration already handled *) - type_ : Sp.t; (* Type declaration already handled *) -} - -(* -------------------------------------------------------------------- *) -let empty_acache : acache = - { op = Sp.empty; type_ = Sp.empty; } +let rec on_mp (cb : cb) (mp : mpath) = + let f = m_functor mp in + cb (`Module f); + List.iter (on_mp cb) mp.m_args -(* -------------------------------------------------------------------- *) -let mkaenv (env : EcEnv.env) (cb : cb) : aenv = - { env; cb; cache = ref empty_acache; } - -(* -------------------------------------------------------------------- *) -let rec on_mp (aenv : aenv) (mp : mpath) = - aenv.cb (`Module (m_functor mp)); - List.iter (on_mp aenv) mp.m_args - -(* -------------------------------------------------------------------- *) -and on_xp (aenv : aenv) (xp : xpath) = - on_mp aenv xp.x_top - -(* -------------------------------------------------------------------- *) -and on_memtype (aenv : aenv) (mt : EcMemory.memtype) = - EcMemory.mt_iter_ty (on_ty aenv) mt +let on_xp (cb : cb) (xp : xpath) = + on_mp cb xp.x_top -(* -------------------------------------------------------------------- *) -and on_memenv (aenv : aenv) (m : EcMemory.memenv) = - on_memtype aenv (snd m) - -(* -------------------------------------------------------------------- *) -and on_pv (aenv : aenv) (pv : prog_var)= +let rec on_ty (cb : cb) (ty : ty) = + match ty.ty_node with + | Tunivar _ -> () + | Tvar _ -> () + | Tglob _ -> () + | Ttuple tys -> List.iter (on_ty cb) tys + | Tconstr (p, tys) -> cb (`Type p); List.iter (on_etyarg cb) tys + | Tfun (ty1, ty2) -> List.iter (on_ty cb) [ty1; ty2] + +and on_etyarg cb ((ty, tcw) : etyarg) = + on_ty cb ty; + List.iter (on_tcwitness cb) tcw + +and on_tcwitness cb (tcw : tcwitness) = + match tcw with + | TCIUni _ -> + () + + | TCIConcrete { path; etyargs } -> + List.iter (on_etyarg cb) etyargs; + cb (`TcInstance (`General path)) + + | TCIAbstract { support = `Abs path } -> + cb (`Type path) + + | TCIAbstract { support = `Var _ } -> + () + +let on_pv (cb : cb) (pv : prog_var)= match pv with - | PVglob xp -> on_xp aenv xp + | PVglob xp -> on_xp cb xp | _ -> () -(* -------------------------------------------------------------------- *) -and on_lp (aenv : aenv) (lp : lpattern) = +let on_lp (cb : cb) (lp : lpattern) = match lp with - | LSymbol (_, ty) -> on_ty aenv ty - | LTuple xs -> List.iter (fun (_, ty) -> on_ty aenv ty) xs - | LRecord (_, xs) -> List.iter (on_ty aenv -| snd) xs - -(* -------------------------------------------------------------------- *) -and on_binding (aenv : aenv) ((_, ty) : (EcIdent.t * ty)) = - on_ty aenv ty - -(* -------------------------------------------------------------------- *) -and on_bindings (aenv : aenv) (bds : (EcIdent.t * ty) list) = - List.iter (on_binding aenv) bds - -(* -------------------------------------------------------------------- *) -and on_ty (aenv : aenv) (ty : ty) = - match ty.ty_node with - | Tunivar _ -> () - | Tvar _ -> () - | Tglob m -> aenv.cb (`Module (mident m)) - | Ttuple tys -> List.iter (on_ty aenv) tys - | Tconstr (p, tys) -> on_tyname aenv p; List.iter (on_ty aenv) tys - | Tfun (ty1, ty2) -> List.iter (on_ty aenv) [ty1; ty2] + | LSymbol (_, ty) -> on_ty cb ty + | LTuple xs -> List.iter (fun (_, ty) -> on_ty cb ty) xs + | LRecord (_, xs) -> List.iter (snd |- on_ty cb) xs -(* -------------------------------------------------------------------- *) -and on_tyname (aenv : aenv) (p : path) = - aenv.cb (`Type p); - if not (Sp.mem p !(aenv.cache).type_) then begin - let cache = { !(aenv.cache) with type_ = Sp.add p !(aenv.cache).type_ } in - aenv.cache := cache; - on_tydecl aenv (EcEnv.Ty.by_path p aenv.env) - end +let on_binding (cb : cb) ((_, ty) : (EcIdent.t * ty)) = + on_ty cb ty -(* -------------------------------------------------------------------- *) -and on_opname (aenv : aenv) (p : EcPath.path) = - aenv.cb (`Op p); - if not (Sp.mem p !(aenv.cache).op) then begin - let cache = { !(aenv.cache) with op = Sp.add p !(aenv.cache).op } in - aenv.cache := cache; - on_opdecl aenv (EcEnv.Op.by_path p aenv.env); - end +let on_bindings (cb : cb) (bds : (EcIdent.t * ty) list) = + List.iter (on_binding cb) bds -(* -------------------------------------------------------------------- *) -and on_expr (aenv : aenv) (e : expr) = - let cbrec = on_expr aenv in +let rec on_expr (cb : cb) (e : expr) = + let cbrec = on_expr cb in let fornode () = match e.e_node with | Eint _ -> () | Elocal _ -> () - | Equant (_, bds, e) -> on_bindings aenv bds; cbrec e - | Evar pv -> on_pv aenv pv - | Elet (lp, e1, e2) -> on_lp aenv lp; List.iter cbrec [e1; e2] + | Equant (_, bds, e) -> on_bindings cb bds; cbrec e + | Evar pv -> on_pv cb pv + | Elet (lp, e1, e2) -> on_lp cb lp; List.iter cbrec [e1; e2] | Etuple es -> List.iter cbrec es + | Eop (p, tys) -> cb (`Op p); List.iter (on_etyarg cb) tys | Eapp (e, es) -> List.iter cbrec (e :: es) | Eif (c, e1, e2) -> List.iter cbrec [c; e1; e2] - | Ematch (e, es, ty) -> on_ty aenv ty; List.iter cbrec (e :: es) + | Ematch (e, es, ty) -> on_ty cb ty; List.iter cbrec (e :: es) | Eproj (e, _) -> cbrec e - | Eop (p, tys) -> begin - on_opname aenv p; - List.iter (on_ty aenv) tys; - end - - in on_ty aenv e.e_ty; fornode () + in on_ty cb e.e_ty; fornode () -(* -------------------------------------------------------------------- *) -and on_lv (aenv : aenv) (lv : lvalue) = - let for1 (pv, ty) = on_pv aenv pv; on_ty aenv ty in +let on_lv (cb : cb) (lv : lvalue) = + let for1 (pv, ty) = on_pv cb pv; on_ty cb ty in match lv with | LvVar pv -> for1 pv | LvTuple pvs -> List.iter for1 pvs -(* -------------------------------------------------------------------- *) -and on_instr (aenv : aenv) (i : instr)= +let rec on_instr (cb : cb) (i : instr)= match i.i_node with | Srnd (lv, e) | Sasgn (lv, e) -> - on_lv aenv lv; - on_expr aenv e + on_lv cb lv; + on_expr cb e | Sraise e -> - on_expr aenv e + on_expr cb e | Scall (lv, f, args) -> - oiter (on_lv aenv) lv; - on_xp aenv f; - List.iter (on_expr aenv) args + lv |> oiter (on_lv cb); + on_xp cb f; + List.iter (on_expr cb) args | Sif (e, s1, s2) -> - on_expr aenv e; - List.iter (on_stmt aenv) [s1; s2] + on_expr cb e; + List.iter (on_stmt cb) [s1; s2] | Swhile (e, s) -> - on_expr aenv e; - on_stmt aenv s + on_expr cb e; + on_stmt cb s | Smatch (e, b) -> let forb (bs, s) = - List.iter (on_ty aenv -| snd) bs; - on_stmt aenv s - in on_expr aenv e; List.iter forb b + List.iter (snd |- on_ty cb) bs; + on_stmt cb s + in on_expr cb e; List.iter forb b | Sabstract _ -> () -(* -------------------------------------------------------------------- *) -and on_stmt (aenv : aenv) (s : stmt) = - List.iter (on_instr aenv) s.s_node +and on_stmt (cb : cb) (s : stmt) = + List.iter (on_instr cb) s.s_node -(* -------------------------------------------------------------------- *) -and on_form (aenv : aenv) (f : EcFol.form) = - let cbrec = on_form aenv in +let on_memtype cb mt = + EcMemory.mt_iter_ty (on_ty cb) mt + +let on_memenv cb (m : EcMemory.memenv) = + on_memtype cb (snd m) + +let rec on_form (cb : cb) (f : EcFol.form) = + let cbrec = on_form cb in let rec fornode () = match f.EcAst.f_node with | EcAst.Fint _ -> () | EcAst.Flocal _ -> () - | EcAst.Fquant (_, b, f) -> on_gbindings aenv b; cbrec f + | EcAst.Fquant (_, b, f) -> on_gbindings cb b; cbrec f | EcAst.Fif (f1, f2, f3) -> List.iter cbrec [f1; f2; f3] - | EcAst.Fmatch (b, fs, ty) -> on_ty aenv ty; List.iter cbrec (b :: fs) - | EcAst.Flet (lp, f1, f2) -> on_lp aenv lp; List.iter cbrec [f1; f2] + | EcAst.Fmatch (b, fs, ty) -> on_ty cb ty; List.iter cbrec (b :: fs) + | EcAst.Flet (lp, f1, f2) -> on_lp cb lp; List.iter cbrec [f1; f2] + | EcAst.Fop (p, tys) -> cb (`Op p); List.iter (on_etyarg cb) tys | EcAst.Fapp (f, fs) -> List.iter cbrec (f :: fs) | EcAst.Ftuple fs -> List.iter cbrec fs | EcAst.Fproj (f, _) -> cbrec f - | EcAst.Fpvar (pv, _) -> on_pv aenv pv + | EcAst.Fpvar (pv, _) -> on_pv cb pv | EcAst.Fglob _ -> () - | EcAst.FhoareF hf -> on_hf aenv hf - | EcAst.FhoareS hs -> on_hs aenv hs - | EcAst.FeHoareF hf -> on_ehf aenv hf - | EcAst.FeHoareS hs -> on_ehs aenv hs - | EcAst.FequivF ef -> on_ef aenv ef - | EcAst.FequivS es -> on_es aenv es - | EcAst.FeagerF eg -> on_eg aenv eg - | EcAst.FbdHoareS bhs -> on_bhs aenv bhs - | EcAst.FbdHoareF bhf -> on_bhf aenv bhf - | EcAst.Fpr pr -> on_pr aenv pr - | EcAst.Fop (p, tys) -> begin - on_opname aenv p; - List.iter (on_ty aenv) tys; - end - - and on_hf (aenv : aenv) hf = - on_form aenv (hf_pr hf).inv; - POE.iter (on_form aenv) (hf_po hf).hsi_inv; - on_xp aenv hf.EcAst.hf_f - - and on_hs (aenv : aenv) hs = - on_form aenv (hs_pr hs).inv; - POE.iter (on_form aenv) (hs_po hs).hsi_inv; - on_stmt aenv hs.EcAst.hs_s; - on_memenv aenv hs.EcAst.hs_m - - and on_ef (aenv : aenv) ef = - on_form aenv (EcAst.ef_pr ef).inv; - on_form aenv (EcAst.ef_po ef).inv; - on_xp aenv ef.EcAst.ef_fl; - on_xp aenv ef.EcAst.ef_fr - - and on_es (aenv : aenv) es = - on_form aenv (EcAst.es_pr es).inv; - on_form aenv (EcAst.es_po es).inv; - on_stmt aenv es.EcAst.es_sl; - on_stmt aenv es.EcAst.es_sr; - on_memenv aenv es.EcAst.es_ml; - on_memenv aenv es.EcAst.es_mr - - and on_eg (aenv : aenv) eg = - on_form aenv (EcAst.eg_pr eg).inv; - on_form aenv (EcAst.eg_po eg).inv; - on_xp aenv eg.EcAst.eg_fl; - on_xp aenv eg.EcAst.eg_fr; - on_stmt aenv eg.EcAst.eg_sl; - on_stmt aenv eg.EcAst.eg_sr; - - and on_ehf (aenv : aenv) hf = - on_form aenv (EcAst.ehf_pr hf).inv; - on_form aenv (EcAst.ehf_po hf).inv; - on_xp aenv hf.EcAst.ehf_f - - and on_ehs (aenv : aenv) hs = - on_form aenv (EcAst.ehs_pr hs).inv; - on_form aenv (EcAst.ehs_po hs).inv; - on_stmt aenv hs.EcAst.ehs_s; - on_memenv aenv hs.EcAst.ehs_m - - and on_bhf (aenv : aenv) bhf = - on_form aenv (EcAst.bhf_pr bhf).inv; - on_form aenv (EcAst.bhf_po bhf).inv; - on_form aenv (EcAst.bhf_bd bhf).inv; - on_xp aenv bhf.EcAst.bhf_f - - and on_bhs (aenv : aenv) bhs = - on_form aenv (EcAst.bhs_pr bhs).inv; - on_form aenv (EcAst.bhs_po bhs).inv; - on_form aenv (EcAst.bhs_bd bhs).inv; - on_stmt aenv bhs.EcAst.bhs_s; - on_memenv aenv bhs.EcAst.bhs_m - - - and on_pr (aenv : aenv) pr = - on_xp aenv pr.EcAst.pr_fun; - List.iter (on_form aenv) [pr.EcAst.pr_event.inv; pr.EcAst.pr_args] + | EcAst.FhoareF hf -> on_hf cb hf + | EcAst.FhoareS hs -> on_hs cb hs + | EcAst.FeHoareF hf -> on_ehf cb hf + | EcAst.FeHoareS hs -> on_ehs cb hs + | EcAst.FequivF ef -> on_ef cb ef + | EcAst.FequivS es -> on_es cb es + | EcAst.FeagerF eg -> on_eg cb eg + | EcAst.FbdHoareS bhs -> on_bhs cb bhs + | EcAst.FbdHoareF bhf -> on_bhf cb bhf + | EcAst.Fpr pr -> on_pr cb pr + + and on_hf cb hf = + on_form cb (EcAst.hf_pr hf).inv; + on_form cb (EcAst.hf_po hf).hsi_inv.main; + on_xp cb hf.EcAst.hf_f + + and on_hs cb hs = + on_form cb (EcAst.hs_pr hs).inv; + on_form cb (EcAst.hs_po hs).hsi_inv.main; + on_stmt cb hs.EcAst.hs_s; + on_memenv cb hs.EcAst.hs_m + + and on_ef cb ef = + on_form cb (EcAst.ef_pr ef).inv; + on_form cb (EcAst.ef_po ef).inv; + on_xp cb ef.EcAst.ef_fl; + on_xp cb ef.EcAst.ef_fr + + and on_es cb es = + on_form cb (EcAst.es_pr es).inv; + on_form cb (EcAst.es_po es).inv; + on_stmt cb es.EcAst.es_sl; + on_stmt cb es.EcAst.es_sr; + on_memenv cb es.EcAst.es_ml; + on_memenv cb es.EcAst.es_mr + + and on_eg cb eg = + on_form cb (EcAst.eg_pr eg).inv; + on_form cb (EcAst.eg_po eg).inv; + on_xp cb eg.EcAst.eg_fl; + on_xp cb eg.EcAst.eg_fr; + on_stmt cb eg.EcAst.eg_sl; + on_stmt cb eg.EcAst.eg_sr; + + and on_ehf cb hf = + on_form cb (EcAst.ehf_pr hf).inv; + on_form cb (EcAst.ehf_po hf).inv; + on_xp cb hf.EcAst.ehf_f + + and on_ehs cb hs = + on_form cb (EcAst.ehs_pr hs).inv; + on_form cb (EcAst.ehs_po hs).inv; + on_stmt cb hs.EcAst.ehs_s; + on_memenv cb hs.EcAst.ehs_m + + and on_bhf cb bhf = + on_form cb (EcAst.bhf_pr bhf).inv; + on_form cb (EcAst.bhf_po bhf).inv; + on_form cb (EcAst.bhf_bd bhf).inv; + on_xp cb bhf.EcAst.bhf_f + + and on_bhs cb bhs = + on_form cb (EcAst.bhs_pr bhs).inv; + on_form cb (EcAst.bhs_po bhs).inv; + on_form cb (EcAst.bhs_bd bhs).inv; + on_stmt cb bhs.EcAst.bhs_s; + on_memenv cb bhs.EcAst.bhs_m + + + and on_pr cb pr = + on_xp cb pr.EcAst.pr_fun; + on_form cb pr.EcAst.pr_event.inv; + on_form cb pr.EcAst.pr_args in - on_ty aenv f.EcAst.f_ty; fornode () + on_ty cb f.EcAst.f_ty; fornode () -(* -------------------------------------------------------------------- *) -and on_restr (aenv : aenv) (restr : mod_restr) = - let doit (xs, ms) = Sx.iter (on_xp aenv) xs; Sm.iter (on_mp aenv) ms in +and on_restr (cb : cb) (restr : mod_restr) = + let doit (xs, ms) = Sx.iter (on_xp cb) xs; Sm.iter (on_mp cb) ms in oiter doit restr.ur_pos; doit restr.ur_neg -(* -------------------------------------------------------------------- *) -and on_modty (aenv : aenv) (mty : module_type) = - aenv.cb (`ModuleType mty.mt_name); - List.iter (fun (_, mty) -> on_modty aenv mty) mty.mt_params; - List.iter (on_mp aenv) mty.mt_args +and on_modty cb (mty : module_type) = + cb (`ModuleType mty.mt_name); + List.iter (fun (_, mty) -> on_modty cb mty) mty.mt_params; + List.iter (on_mp cb) mty.mt_args -(* -------------------------------------------------------------------- *) -and on_mty_mr (aenv : aenv) ((mty, mr) : mty_mr) = - on_modty aenv mty; on_restr aenv mr +and on_mty_mr (cb : cb) ((mty, mr) : mty_mr) = + on_modty cb mty; on_restr cb mr -(* -------------------------------------------------------------------- *) -and on_gbinding (aenv : aenv) (b : gty) = +and on_gbinding (cb : cb) (b : gty) = match b with | EcAst.GTty ty -> - on_ty aenv ty + on_ty cb ty | EcAst.GTmodty mty -> - on_mty_mr aenv mty + on_mty_mr cb mty | EcAst.GTmem m -> - on_memtype aenv m + on_memtype cb m -(* -------------------------------------------------------------------- *) -and on_gbindings (aenv : aenv) (b : (EcIdent.t * gty) list) = - List.iter (fun (_, b) -> on_gbinding aenv b) b +and on_gbindings (cb : cb) (b : (EcIdent.t * gty) list) = + List.iter (fun (_, b) -> on_gbinding cb b) b -(* -------------------------------------------------------------------- *) -and on_module (aenv : aenv) (me : module_expr) = +and on_module (cb : cb) (me : module_expr) = match me.me_body with - | ME_Alias (_, mp) -> on_mp aenv mp - | ME_Structure st -> on_mstruct aenv st - | ME_Decl mty -> on_mty_mr aenv mty + | ME_Alias (_, mp) -> on_mp cb mp + | ME_Structure st -> on_mstruct cb st + | ME_Decl mty -> on_mty_mr cb mty -(* -------------------------------------------------------------------- *) -and on_mstruct (aenv : aenv) (st : module_structure) = - List.iter (on_mstruct1 aenv) st.ms_body +and on_mstruct (cb : cb) (st : module_structure) = + List.iter (on_mpath_mstruct1 cb) st.ms_body -(* -------------------------------------------------------------------- *) -and on_mstruct1 (aenv : aenv) (item : module_item) = +and on_mpath_mstruct1 (cb : cb) (item : module_item) = match item with - | MI_Module me -> on_module aenv me - | MI_Variable x -> on_ty aenv x.v_type - | MI_Function f -> on_fun aenv f + | MI_Module me -> on_module cb me + | MI_Variable x -> on_ty cb x.v_type + | MI_Function f -> on_fun cb f -(* -------------------------------------------------------------------- *) -and on_fun (aenv : aenv) (fun_ : function_) = - on_fun_sig aenv fun_.f_sig; - on_fun_body aenv fun_.f_def +and on_fun (cb : cb) (fun_ : function_) = + on_fun_sig cb fun_.f_sig; + on_fun_body cb fun_.f_def -(* -------------------------------------------------------------------- *) -and on_fun_sig (aenv : aenv) (fsig : funsig) = - on_ty aenv fsig.fs_arg; - on_ty aenv fsig.fs_ret +and on_fun_sig (cb : cb) (fsig : funsig) = + on_ty cb fsig.fs_arg; + on_ty cb fsig.fs_ret -(* -------------------------------------------------------------------- *) -and on_fun_body (aenv : aenv) (fbody : function_body) = +and on_fun_body (cb : cb) (fbody : function_body) = match fbody with - | FBalias xp -> on_xp aenv xp - | FBdef fdef -> on_fun_def aenv fdef - | FBabs oi -> on_oi aenv oi + | FBalias xp -> on_xp cb xp + | FBdef fdef -> on_fun_def cb fdef + | FBabs oi -> on_oi cb oi -(* -------------------------------------------------------------------- *) -and on_fun_def (aenv : aenv) (fdef : function_def) = - List.iter (fun v -> on_ty aenv v.v_type) fdef.f_locals; - on_stmt aenv fdef.f_body; - fdef.f_ret |> oiter (on_expr aenv); - on_uses aenv fdef.f_uses +and on_fun_def (cb : cb) (fdef : function_def) = + List.iter (fun v -> on_ty cb v.v_type) fdef.f_locals; + on_stmt cb fdef.f_body; + fdef.f_ret |> oiter (on_expr cb); + on_uses cb fdef.f_uses -(* -------------------------------------------------------------------- *) -and on_uses (aenv : aenv) (uses : uses) = - List.iter (on_xp aenv) uses.us_calls; - Sx.iter (on_xp aenv) uses.us_reads; - Sx.iter (on_xp aenv) uses.us_writes +and on_uses (cb : cb) (uses : uses) = + List.iter (on_xp cb) uses.us_calls; + Sx.iter (on_xp cb) uses.us_reads; + Sx.iter (on_xp cb) uses.us_writes -(* -------------------------------------------------------------------- *) -and on_oi (aenv : aenv) (oi : OI.t) = - List.iter (on_xp aenv) (OI.allowed oi) +and on_oi (cb : cb) (oi : OI.t) = + List.iter (on_xp cb) (OI.allowed oi) (* -------------------------------------------------------------------- *) -and on_typarams (_aenv : aenv) (_typarams : ty_params) = - () +let on_typeclass cb tc = + cb (`Typeclass tc.tc_name); + List.iter (on_etyarg cb) tc.tc_args + +let on_typeclasses cb tcs = + List.iter (on_typeclass cb) tcs + +let on_typarams cb typarams = + List.iter (fun (_, tc) -> on_typeclasses cb tc) typarams (* -------------------------------------------------------------------- *) -and on_tydecl (aenv : aenv) (tyd : tydecl) = - on_typarams aenv tyd.tyd_params; +let on_tydecl (cb : cb) (tyd : tydecl) = + on_typarams cb tyd.tyd_params; match tyd.tyd_type with - | Concrete ty -> on_ty aenv ty - | Abstract -> () - | Record (f, fds) -> - on_form aenv f; - List.iter (on_ty aenv -| snd) fds - | Datatype dt -> - List.iter (List.iter (on_ty aenv) -| snd) dt.tydt_ctors; - List.iter (on_form aenv) [dt.tydt_schelim; dt.tydt_schcase] - -and on_typeclass (aenv : aenv) tc = - oiter (fun p -> aenv.cb (`Typeclass p)) tc.tc_prt; - List.iter (fun (_,ty) -> on_ty aenv ty) tc.tc_ops; - List.iter (fun (_,f) -> on_form aenv f) tc.tc_axs + | `Concrete ty -> on_ty cb ty + | `Abstract s -> on_typeclasses cb s + | `Record (f, fds) -> + on_form cb f; + List.iter (snd |- on_ty cb) fds + | `Datatype dt -> + List.iter (snd |- List.iter (on_ty cb)) dt.tydt_ctors; + List.iter (on_form cb) [dt.tydt_schelim; dt.tydt_schcase] + +let on_tcdecl cb tc = + List.iter (fun (p, _ren) -> on_typeclass cb p) tc.tc_prts; + List.iter (fun (_,ty) -> on_ty cb ty) tc.tc_ops; + List.iter (fun (_,f) -> on_form cb f) tc.tc_axs (* -------------------------------------------------------------------- *) -and on_opdecl (aenv : aenv) (opdecl : operator) = - on_typarams aenv opdecl.op_tparams; +let on_opdecl (cb : cb) (opdecl : operator) = + on_typarams cb opdecl.op_tparams; let for_kind () = match opdecl.op_kind with | OB_pred None -> () | OB_pred (Some (PR_Plain f)) -> - on_form aenv f + on_form cb f | OB_pred (Some (PR_Ind pri)) -> - on_bindings aenv pri.pri_args; + on_bindings cb pri.pri_args; List.iter (fun ctor -> - on_gbindings aenv ctor.prc_bds; - List.iter (on_form aenv) ctor.prc_spec) + on_gbindings cb ctor.prc_bds; + List.iter (on_form cb) ctor.prc_spec) pri.pri_ctors | OB_nott nott -> - List.iter (on_ty aenv -| snd) nott.ont_args; - on_ty aenv nott.ont_resty; - on_expr aenv nott.ont_body + List.iter (snd |- on_ty cb) nott.ont_args; + on_ty cb nott.ont_resty; + on_expr cb nott.ont_body | OB_oper None -> () | OB_oper Some b -> match b with - | OP_Exn ty -> List.iter (on_ty aenv) ty - | OP_Constr _ | OP_Record _ | OP_Proj _ | OP_TC -> () - | OP_Plain f -> on_form aenv f + | OP_Constr _ | OP_Record _ | OP_Proj _ -> assert false + | OP_TC _ -> assert false + | OP_Exn ty -> List.iter (on_ty cb) ty + | OP_Plain f -> on_form cb f | OP_Fix f -> let rec on_mpath_branches br = match br with | OPB_Leaf (bds, e) -> - List.iter (on_bindings aenv) bds; - on_expr aenv e + List.iter (on_bindings cb) bds; + on_expr cb e | OPB_Branch br -> Parray.iter on_mpath_branch br @@ -496,48 +446,49 @@ and on_opdecl (aenv : aenv) (opdecl : operator) = in on_mpath_branches f.opf_branches - in on_ty aenv opdecl.op_ty; for_kind () + in on_ty cb opdecl.op_ty; for_kind () (* -------------------------------------------------------------------- *) -and on_axiom (aenv : aenv) (ax : axiom) = - on_typarams aenv ax.ax_tparams; - on_form aenv ax.ax_spec +let on_axiom (cb : cb) (ax : axiom) = + on_typarams cb ax.ax_tparams; + on_form cb ax.ax_spec (* -------------------------------------------------------------------- *) -and on_modsig (aenv : aenv) (ms:module_sig) = - List.iter (fun (_,mt) -> on_modty aenv mt) ms.mis_params; +let on_modsig (cb:cb) (ms:module_sig) = + List.iter (fun (_,mt) -> on_modty cb mt) ms.mis_params; List.iter (fun (Tys_function fs) -> - on_ty aenv fs.fs_arg; - List.iter (fun x -> on_ty aenv x.ov_type) fs.fs_anames; - on_ty aenv fs.fs_ret;) ms.mis_body; - Msym.iter (fun _ oi -> on_oi aenv oi) ms.mis_oinfos - -(* -------------------------------------------------------------------- *) -and on_ring (aenv : aenv) (r : ring) = - on_ty aenv r.r_type; - let on_p p = on_opname aenv p in + on_ty cb fs.fs_arg; + List.iter (fun x -> on_ty cb x.ov_type) fs.fs_anames; + on_ty cb fs.fs_ret;) ms.mis_body; + Msym.iter (fun _ oi -> on_oi cb oi) ms.mis_oinfos + +let on_ring cb r = + on_ty cb r.r_type; + let on_p p = cb (`Op p) in List.iter on_p [r.r_zero; r.r_one; r.r_add; r.r_mul]; List.iter (oiter on_p) [r.r_opp; r.r_exp; r.r_sub]; match r.r_embed with | `Direct | `Default -> () | `Embed p -> on_p p -(* -------------------------------------------------------------------- *) -and on_field (aenv : aenv) (f : field) = - on_ring aenv f.f_ring; - let on_p p = on_opname aenv p in +let on_field cb f = + on_ring cb f.f_ring; + let on_p p = cb (`Op p) in on_p f.f_inv; oiter on_p f.f_div -(* -------------------------------------------------------------------- *) -and on_instance (aenv : aenv) ty tci = - on_typarams aenv (fst ty); - on_ty aenv (snd ty); - match tci with - | `Ring r -> on_ring aenv r - | `Field f -> on_field aenv f - | `General p -> - (* FIXME section: ring/field use type class that do not exists *) - aenv.cb (`Typeclass p) +let on_instance cb tci = + on_typarams cb tci.tci_params; + on_ty cb tci.tci_type; + (* FIXME section: ring/field use type class that do not exists *) + match tci.tci_instance with + | `Ring r -> on_ring cb r + | `Field f -> on_field cb f + + | `General (tci, syms) -> + on_typeclass cb tci; + Option.iter + (Mstr.iter (fun _ (p, tys) -> cb (`Op p); List.iter (on_etyarg cb) tys)) + syms (* -------------------------------------------------------------------- *) type sc_name = @@ -585,19 +536,23 @@ let pp_thname scenv = (* -------------------------------------------------------------------- *) let locality (env : EcEnv.env) (who : cbarg) = match who with - | `Type p -> (EcEnv.Ty.by_path p env).tyd_loca - | `Op p -> (EcEnv.Op.by_path p env).op_loca - | `Ax p -> (EcEnv.Ax.by_path p env).ax_loca - | `Typeclass p -> ((EcEnv.TypeClass.by_path p env).tc_loca :> locality) - | `Module mp -> begin - match EcEnv.Mod.by_mpath_opt mp env with + | `Type p -> begin + match EcEnv.TypeClass.by_path_opt p env with + | Some tc -> (tc.tc_loca :> locality) + | _ -> (EcEnv.Ty.by_path p env).tyd_loca + end + | `Op p -> (EcEnv. Op.by_path p env).op_loca + | `Ax p -> (EcEnv. Ax.by_path p env).ax_loca + | `Typeclass p -> ((EcEnv.TypeClass.by_path p env).tc_loca :> locality) + | `Module mp -> + begin match EcEnv.Mod.by_mpath_opt mp env with | Some (_, Some lc) -> lc - | _ -> - let id = EcPath.mget_ident mp in - if EcEnv.Mod.is_declared id env then `Declare else `Global + (* in this case it should be a quantified module *) + | _ -> `Global end | `ModuleType p -> ((EcEnv.ModTy.by_path p env).tms_loca :> locality) - | `Instance _ -> assert false + | `TcInstance (`General p) -> (EcEnv.TcInstance.by_path p env).tci_local + | `TcInstance (`Ring | `Field) -> `Global (* -------------------------------------------------------------------- *) type to_clear = @@ -654,11 +609,17 @@ let add_declared_mod to_gen id modty = let add_declared_ty to_gen path tydecl = assert (tydecl.tyd_params = []); - let name = "'" ^ basename path in - let id = EcIdent.create name in + let s = + match tydecl.tyd_type with + | `Abstract s -> s + | _ -> assert false in + + let name = Format.sprintf "'%s" (basename path) in + let id = EcIdent.create name in + { to_gen with - tg_params = to_gen.tg_params @ [id]; - tg_subst = EcSubst.add_tydef to_gen.tg_subst path ([], tvar id); + tg_params = to_gen.tg_params @ [id, s]; + tg_subst = EcSubst.add_tydef to_gen.tg_subst path ([], tvar id, []); } let add_declared_op to_gen path opdecl = @@ -680,14 +641,22 @@ let add_declared_op to_gen path opdecl = | _ -> assert false } let tvar_fv ty = Mid.map (fun () -> 1) (EcTypes.Tvar.fv ty) + + let etyargs_tvar_fv etyargs = + Mid.map (fun () -> 1) (EcTypes.etyargs_tvar_fv etyargs) + let fv_and_tvar_e e = let rec aux fv e = let fv = EcIdent.fv_union fv (tvar_fv e.e_ty) in match e.e_node with - | Eop(_, tys) -> List.fold_left (fun fv ty -> EcIdent.fv_union fv (tvar_fv ty)) fv tys + | Eop(_, etyargs) -> + EcIdent.fv_union fv (etyargs_tvar_fv etyargs) | Equant(_,d,e) -> - let fv = List.fold_left (fun fv (_,ty) -> EcIdent.fv_union fv (tvar_fv ty)) fv d in - aux fv e + let fv = + List.fold_left + (fun fv (_,ty) -> EcIdent.fv_union fv (tvar_fv ty)) + fv d + in aux fv e | _ -> e_fold aux fv e in aux e.e_fv e @@ -706,7 +675,8 @@ and fv_and_tvar_f f = let rec aux f = fv := EcIdent.fv_union !fv (tvar_fv f.f_ty); match f.f_node with - | Fop(_, tys) -> fv := List.fold_left (fun fv ty -> EcIdent.fv_union fv (tvar_fv ty)) !fv tys + | Fop(_, tys) -> + fv := EcIdent.fv_union !fv (etyargs_tvar_fv tys) | Fquant(_, d, f) -> fv := List.fold_left (fun fv (_,gty) -> EcIdent.fv_union fv (gty_fv_and_tvar gty)) !fv d; aux f @@ -720,23 +690,29 @@ and fv_and_tvar_f f = let tydecl_fv tyd = let fv = match tyd.tyd_type with - | Concrete ty -> ty_fv_and_tvar ty - | Abstract -> Mid.empty - | Datatype tydt -> + | `Concrete ty -> ty_fv_and_tvar ty + | `Abstract _ -> Mid.empty + | `Datatype tydt -> List.fold_left (fun fv (_, l) -> List.fold_left (fun fv ty -> EcIdent.fv_union fv (ty_fv_and_tvar ty)) fv l) Mid.empty tydt.tydt_ctors - | Record (_f, l) -> + | `Record (_f, l) -> List.fold_left (fun fv (_, ty) -> EcIdent.fv_union fv (ty_fv_and_tvar ty)) Mid.empty l in - List.fold_left (fun fv id -> Mid.remove id fv) fv tyd.tyd_params + let fv = + match tyd.tyd_subtype with + | None -> fv + | Some (carrier, pred) -> + EcIdent.fv_union fv + (EcIdent.fv_union (ty_fv_and_tvar carrier) (fv_and_tvar_f pred)) in + List.fold_left (fun fv (id, _) -> Mid.remove id fv) fv tyd.tyd_params let op_body_fv body ty = let fv = ty_fv_and_tvar ty in match body with | OP_Plain f -> EcIdent.fv_union fv (fv_and_tvar_f f) - | OP_Constr _ | OP_Record _ | OP_Proj _ | OP_TC | OP_Exn _ -> fv + | OP_Constr _ | OP_Record _ | OP_Proj _ | OP_TC _ | OP_Exn _ -> fv | OP_Fix opfix -> let fv = List.fold_left (fun fv (_, ty) -> EcIdent.fv_union fv (ty_fv_and_tvar ty)) @@ -775,7 +751,7 @@ let notation_fv nota = EcIdent.fv_union (Mid.remove id fv) (ty_fv_and_tvar ty)) fv nota.ont_args let generalize_extra_ty to_gen fv = - List.filter (fun id -> Mid.mem id fv) to_gen.tg_params + List.filter (fun (id,_) -> Mid.mem id fv) to_gen.tg_params let rec generalize_extra_args binds fv = match binds with @@ -811,14 +787,14 @@ let generalize_tydecl to_gen prefix (name, tydecl) = let fv = tydecl_fv tydecl in let extra = generalize_extra_ty to_gen fv in let tyd_params = extra @ tydecl.tyd_params in - let args = List.map tvar tyd_params in - let params = tydecl.tyd_params in - let tosubst = params, tconstr path args in + let args = List.map (fun (id, _) -> tvar id) tyd_params in + let fst_params = List.map fst tydecl.tyd_params in + let tosubst = (fst_params, tconstr path args, []) in let tg_subst, tyd_type = match tydecl.tyd_type with - | Concrete _ | Abstract -> + | `Concrete _ | `Abstract _ -> EcSubst.add_tydef to_gen.tg_subst path tosubst, tydecl.tyd_type - | Record (f, prs) -> + | `Record (f, prs) -> let subst = EcSubst.empty in let tg_subst = to_gen.tg_subst in let subst = EcSubst.add_tydef subst path tosubst in @@ -828,15 +804,15 @@ let generalize_tydecl to_gen prefix (name, tydecl) = let tin = tconstr path args in let add_op (s, ty) = let p = pqname prefix s in - let tosubst = params, e_op p args (tfun tin ty) in + let tosubst = fst_params, e_op p args (tfun tin ty) in rsubst := EcSubst.add_opdef !rsubst p tosubst; rtg_subst := EcSubst.add_opdef !rtg_subst p tosubst; s, ty in let prs = List.map add_op prs in let f = EcSubst.subst_form !rsubst f in - !rtg_subst, Record (f, prs) - | Datatype dt -> + !rtg_subst, `Record (f, prs) + | `Datatype dt -> let subst = EcSubst.empty in let tg_subst = to_gen.tg_subst in let subst = EcSubst.add_tydef subst path tosubst in @@ -849,21 +825,23 @@ let generalize_tydecl to_gen prefix (name, tydecl) = let tys = List.map subst_ty tys in let p = pqname prefix s in let pty = toarrow tys tout in - let tosubst = params, e_op p args pty in + let tosubst = fst_params, e_op p args pty in rsubst := EcSubst.add_opdef !rsubst p tosubst; rtg_subst := EcSubst.add_opdef !rtg_subst p tosubst ; s, tys in let tydt_ctors = List.map add_op dt.tydt_ctors in let tydt_schelim = EcSubst.subst_form !rsubst dt.tydt_schelim in let tydt_schcase = EcSubst.subst_form !rsubst dt.tydt_schcase in - !rtg_subst, Datatype {tydt_ctors; tydt_schelim; tydt_schcase } + !rtg_subst, `Datatype {tydt_ctors; tydt_schelim; tydt_schcase } in let to_gen = { to_gen with tg_subst} in let tydecl = { tyd_params; tyd_type; - tyd_loca = `Global; } in + tyd_loca = `Global; + tyd_resolve = tydecl.tyd_resolve; + tyd_subtype = tydecl.tyd_subtype; } in to_gen, Some (Th_type (name, tydecl)) | `Declare -> @@ -883,8 +861,9 @@ let generalize_opdecl to_gen prefix (name, operator) = let extra = generalize_extra_ty to_gen fv in let tparams = extra @ operator.op_tparams in let opty = operator.op_ty in - let args = List.map tvar tparams in - let tosubst = (operator.op_tparams, e_op path args opty) in + let args = List.map (fun (id, _) -> tvar id) tparams in + let tosubst = (List.map fst operator.op_tparams, + e_op path args opty) in let tg_subst = EcSubst.add_opdef to_gen.tg_subst path tosubst in tg_subst, mk_op ~opaque:operator.op_opaque tparams opty None `Global @@ -894,8 +873,9 @@ let generalize_opdecl to_gen prefix (name, operator) = let extra = generalize_extra_ty to_gen fv in let tparams = extra @ operator.op_tparams in let opty = operator.op_ty in - let args = List.map tvar tparams in - let tosubst = (operator.op_tparams, f_op path args opty) in + let etyargs = EcDecl.etyargs_of_tparams tparams in + let tosubst = (List.map fst operator.op_tparams, + f_op_tc path etyargs opty) in let tg_subst = EcSubst.add_pddef to_gen.tg_subst path tosubst in tg_subst, mk_op ~opaque:operator.op_opaque tparams opty None `Global @@ -906,18 +886,20 @@ let generalize_opdecl to_gen prefix (name, operator) = let tparams = extra_t @ operator.op_tparams in let extra_a = generalize_extra_args to_gen.tg_binds fv in let opty = toarrow (List.map snd extra_a) operator.op_ty in - let t_args = List.map tvar tparams in - let eop = e_op path t_args opty in + let etyargs = EcDecl.etyargs_of_tparams tparams in + let eop = e_op_tc path etyargs opty in let e = e_app eop (List.map (fun (id,ty) -> e_local id ty) extra_a) operator.op_ty in - let tosubst = (operator.op_tparams, e) in + let tosubst = + (List.map fst operator.op_tparams, e) in let tg_subst = EcSubst.add_opdef to_gen.tg_subst path tosubst in let body = match body with - | OP_Constr _ | OP_Record _ | OP_Proj _ | OP_Exn _ -> assert false - | OP_TC -> assert false (* ??? *) + | OP_Constr _ | OP_Record _ | OP_Proj _ -> assert false + | OP_TC _ -> assert false (* ??? *) + | OP_Exn _ -> assert false | OP_Plain f -> OP_Plain (f_lambda (List.map (fun (x, ty) -> (x, GTty ty)) extra_a) f) | OP_Fix opfix -> @@ -944,12 +926,13 @@ let generalize_opdecl to_gen prefix (name, operator) = let op_tparams = extra_t @ operator.op_tparams in let extra_a = generalize_extra_args to_gen.tg_binds fv in let op_ty = toarrow (List.map snd extra_a) operator.op_ty in - let t_args = List.map tvar op_tparams in - let fop = f_op path t_args op_ty in + let etyargs = EcDecl.etyargs_of_tparams op_tparams in + let fop = f_op_tc path etyargs op_ty in let f = f_app fop (List.map (fun (id,ty) -> f_local id ty) extra_a) operator.op_ty in - let tosubst = (operator.op_tparams, f) in + let tosubst = + (List.map fst operator.op_tparams, f) in let tg_subst = EcSubst.add_pddef to_gen.tg_subst path tosubst in let body = @@ -1046,7 +1029,7 @@ let generalize_module to_gen prefix me = | _ -> () in try - on_mp (mkaenv to_gen.tg_env.sc_env check_gen) mp; + on_mp check_gen mp; to_gen, Some (Th_module me) with Inline -> @@ -1068,11 +1051,11 @@ let generalize_export to_gen (p,lc) = if lc = `Local || to_clear to_gen (`Th p) then to_gen, None else to_gen, Some (Th_export (p,lc)) -let generalize_instance to_gen (ty,tci, lc) = - if lc = `Local then to_gen, None - (* FIXME: be sure that we have no dep to declare or local, - or fix this code *) - else to_gen, Some (Th_instance (ty,tci,lc)) +let generalize_instance to_gen (x, tci) = + if tci.tci_local = `Local then to_gen, None + else + let tci = EcSubst.subst_tcinstance to_gen.tg_subst tci in + to_gen, Some (Th_instance (x, tci)) let generalize_baserw to_gen prefix (s,lc) = if lc = `Local then @@ -1089,20 +1072,21 @@ let generalize_addrw to_gen (p, ps, lc) = let generalize_reduction to_gen _rl = to_gen, None -let generalize_auto to_gen auto_rl = - if auto_rl.locality = `Local then - to_gen, None +let generalize_auto to_gen { level=n; base=s; axioms=ps; locality=lc } = + if lc = `Local then to_gen, None else - let axioms = - List.filter (fun (p, _) -> to_keep to_gen (`Ax p)) auto_rl.axioms in - if List.is_empty axioms then - to_gen, None - else - to_gen, Some (Th_auto {auto_rl with axioms}) + let ps = List.filter (fun (p, _) -> to_keep to_gen (`Ax p)) ps in + if ps = [] then to_gen, None + else to_gen, Some (Th_auto { level=n; base=s; axioms=ps; locality=lc }) (* --------------------------------------------------------------- *) let get_locality scenv = scenv.sc_loca +let set_local l = + match l with + | `Global -> `Local + | _ -> l + let id_lc = function | `Global -> `Global | `Local -> `Local @@ -1119,44 +1103,61 @@ let rec set_lc_item lc_override item = | Th_axiom (s,ax) -> Th_axiom (s, { ax with ax_loca = set_lc lc_override ax.ax_loca }) | Th_modtype (s,ms) -> Th_modtype (s, { ms with tms_loca = set_lc lc_override ms.tms_loca }) | Th_module me -> Th_module { me with tme_loca = set_lc lc_override me.tme_loca } - | Th_theory (s, th) -> Th_theory (s, set_local_th lc_override th) + | Th_typeclass (s,tc) -> Th_typeclass (s, { tc with tc_loca = set_lc lc_override tc.tc_loca }) + | Th_theory (s, th) -> Th_theory (s, set_lc_th lc_override th) | Th_export (p,lc) -> Th_export (p, set_lc lc_override lc) - | Th_instance (ty,ti,lc) -> Th_instance (ty,ti, set_lc lc_override lc) + | Th_instance (x,tci) -> Th_instance (x, { tci with tci_local = set_lc lc_override tci.tci_local }) | Th_baserw (s,lc) -> Th_baserw (s, set_lc lc_override lc) | Th_addrw (p,ps,lc) -> Th_addrw (p, ps, set_lc lc_override lc) | Th_reduction r -> Th_reduction r - | Th_auto auto_rl -> Th_auto {auto_rl with locality=set_lc lc_override auto_rl.locality} - | Th_alias alias -> Th_alias alias + | Th_auto ar -> Th_auto { ar with locality = set_lc lc_override ar.locality } + | Th_alias a -> Th_alias a in { item with ti_item = lcitem } -and set_local_th lc_override th = +and set_lc_th lc_override th = { th with cth_items = List.map (set_lc_item lc_override) th.cth_items; cth_loca = set_lc lc_override th.cth_loca; } +let rec set_local_item item = + let lcitem = + match item.ti_item with + | Th_type (s,ty) -> Th_type (s, { ty with tyd_loca = set_local ty.tyd_loca }) + | Th_operator (s,op) -> Th_operator (s, { op with op_loca = set_local op.op_loca }) + | Th_axiom (s,ax) -> Th_axiom (s, { ax with ax_loca = set_local ax.ax_loca }) + | Th_modtype (s,ms) -> Th_modtype (s, { ms with tms_loca = set_local ms.tms_loca }) + | Th_module me -> Th_module { me with tme_loca = set_local me.tme_loca } + | Th_typeclass (s,tc) -> Th_typeclass (s, { tc with tc_loca = set_local tc.tc_loca }) + | Th_theory (s, th) -> Th_theory (s, set_local_th th) + | Th_export (p,lc) -> Th_export (p, set_local lc) + | Th_instance (x,tci) -> Th_instance (x, { tci with tci_local = set_local tci.tci_local }) + | Th_baserw (s,lc) -> Th_baserw (s, set_local lc) + | Th_addrw (p,ps,lc) -> Th_addrw (p, ps, set_local lc) + | Th_reduction r -> Th_reduction r + | Th_auto ar -> Th_auto { ar with locality = set_local ar.locality } + | Th_alias a -> Th_alias a + + in { item with ti_item = lcitem } + +and set_local_th th = + { th with cth_items = List.map set_local_item th.cth_items; + cth_loca = set_local th.cth_loca; } + +let sc_th_item t item = + let item = + match get_locality t with + | `Global -> item + | `Local -> set_local_item item in + SC_th_item item + let sc_decl_mod (id,mt) = SC_decl_mod (id,mt) + (* ---------------------------------------------------------------- *) let is_abstract_ty = function - | Abstract -> true - | _ -> false - -(* -let rec check_glob_mp_ty s scenv mp = - let mtop = `Module (mastrip mp) in - if is_declared scenv mtop then - hierror "global %s can't depend on declared module" s; - if is_local scenv mtop then - hierror "global %s can't depend on local module" s; - List.iter (check_glob_mp_ty s scenv) mp.m_args - -let rec check_glob_mp scenv mp = - let mtop = `Module (mastrip mp) in - if is_local scenv mtop then - hierror "global definition can't depend on local module"; - List.iter (check_glob_mp scenv) mp.m_args - *) + | `Abstract _ -> true + | _ -> false let check s scenv who b = if not b then @@ -1171,24 +1172,26 @@ let check_polymorph scenv who typarams = let check_abstract = check "should be abstract" type can_depend = { - d_ty : locality list; - d_op : locality list; - d_ax : locality list; - d_sc : locality list; - d_mod : locality list; - d_modty : locality list; - d_tc : locality list; - } + d_ty : locality list; + d_op : locality list; + d_ax : locality list; + d_sc : locality list; + d_mod : locality list; + d_modty : locality list; + d_tc : locality list; + d_tci : locality list; +} -let cd_glob = - { d_ty = [`Global]; - d_op = [`Global]; - d_ax = [`Global]; - d_sc = [`Global]; - d_mod = [`Global]; - d_modty = [`Global]; - d_tc = [`Global]; - } +let cd_glob = { + d_ty = [`Global]; + d_op = [`Global]; + d_ax = [`Global]; + d_sc = [`Global]; + d_mod = [`Global]; + d_modty = [`Global]; + d_tc = [`Global]; + d_tci = [`Global]; +} let can_depend (cd : can_depend) = function | `Type _ -> cd.d_ty @@ -1198,8 +1201,7 @@ let can_depend (cd : can_depend) = function | `Module _ -> cd.d_mod | `ModuleType _ -> cd.d_modty | `Typeclass _ -> cd.d_tc - | `Instance _ -> assert false - + | `TcInstance _ -> cd.d_tci let cb scenv from cd who = let env = scenv.sc_env in @@ -1230,8 +1232,9 @@ let check_tyd scenv prefix name tyd = d_mod = [`Global]; d_modty = []; d_tc = [`Global]; + d_tci = [`Global]; } in - on_tydecl (mkaenv scenv.sc_env (cb scenv from cd)) tyd + on_tydecl (cb scenv from cd) tyd let is_abstract_op op = match op.op_kind with @@ -1256,8 +1259,9 @@ let check_op scenv prefix name op = d_mod = [`Declare; `Global]; d_modty = []; d_tc = [`Global]; + d_tci = [`Global]; } in - on_opdecl (mkaenv scenv.sc_env (cb scenv from cd)) op + on_opdecl (cb scenv from cd) op | `Global -> let cd = { @@ -1268,8 +1272,9 @@ let check_op scenv prefix name op = d_mod = [`Global]; d_modty = []; d_tc = [`Global]; + d_tci = [`Global]; } in - on_opdecl (mkaenv scenv.sc_env (cb scenv from cd)) op + on_opdecl (cb scenv from cd) op let is_inth scenv = match scenv.sc_name with @@ -1287,8 +1292,9 @@ let check_ax (scenv : scenv) (prefix : path) (name : symbol) (ax : axiom) = d_mod = [`Declare; `Global]; d_modty = [`Global]; d_tc = [`Global]; + d_tci = [`Global]; } in - let doit = on_axiom (mkaenv scenv.sc_env (cb scenv from cd)) in + let doit = on_axiom (cb scenv from cd) in let error b s1 s = if b then hierror "%s %a %s" s1 (pp_axname scenv) path s in @@ -1320,7 +1326,7 @@ let check_modtype scenv prefix name ms = | `Local -> check_section scenv from | `Global -> if scenv.sc_insec then - on_modsig (mkaenv scenv.sc_env (cb scenv from cd_glob)) ms.tms_sig + on_modsig (cb scenv from cd_glob) ms.tms_sig let check_module scenv prefix tme = @@ -1330,40 +1336,48 @@ let check_module scenv prefix tme = match tme.tme_loca with | `Local -> check_section scenv from | `Global -> - if scenv.sc_insec then begin - let isalias = EcModules.is_me_body_alias tme.tme_expr.me_body in + if scenv.sc_insec then let cd = { d_ty = [`Global]; d_op = [`Global]; d_ax = []; d_sc = []; - d_mod = [`Global] @ (if isalias then [`Declare] else []); + d_mod = [`Global]; (* FIXME section: add local *) d_modty = [`Global]; d_tc = [`Global]; + d_tci = [`Global]; } in - on_module (mkaenv scenv.sc_env (cb scenv from cd)) me - end + on_module (cb scenv from cd) me | `Declare -> (* Should be SC_decl_mod ... *) assert false -let check_typeclass scenv prefix name tc = +let check_tcdecl scenv prefix name tc = let path = pqname prefix name in let from = ((tc.tc_loca :> locality), `Typeclass path) in if tc.tc_loca = `Local then check_section scenv from else - on_typeclass (mkaenv scenv.sc_env (cb scenv from cd_glob)) tc - -let check_instance scenv ty tci lc = - let from = (lc :> locality), `Instance tci in - if lc = `Local then check_section scenv from + on_tcdecl (cb scenv from cd_glob) tc + +let check_instance scenv prefix x tci = + let from = + match x, tci.tci_instance with + | Some x, `General _ -> `General (pqname prefix x) + | None , `Ring _ -> `Ring + | None , `Field _ -> `Field + | _ , _ -> assert false in + + let from = (tci.tci_local, `TcInstance from) in + + if tci.tci_local = `Local then check_section scenv from else if scenv.sc_insec then - match tci with + match tci.tci_instance with | `Ring _ | `Field _ -> - on_instance (mkaenv scenv.sc_env (cb scenv from cd_glob) )ty tci + on_instance (cb scenv from cd_glob) tci + | `General _ -> - let cd = { cd_glob with d_ty = [`Declare; `Global]; } in - on_instance (mkaenv scenv.sc_env (cb scenv from cd)) ty tci + let cd = { cd_glob with d_ty = [`Declare; `Global]; } in + on_instance (cb scenv from cd) tci (* -----------------------------------------------------------*) let enter_theory (name:symbol) (lc:is_local) (mode:thmode) scenv : scenv = @@ -1396,25 +1410,25 @@ let add_item_ ?(override_locality=None) (item : theory_item) (scenv:scenv) = let import = item.ti_import in let env = match item.ti_item with - | Th_type (s,tyd) -> EcEnv.Ty.bind ~import s tyd env - | Th_operator (s,op) -> EcEnv.Op.bind ~import s op env - | Th_axiom (s, ax) -> EcEnv.Ax.bind ~import s ax env - | Th_modtype (s, ms) -> EcEnv.ModTy.bind ~import s ms env - | Th_module me -> EcEnv.Mod.bind ~import me.tme_expr.me_name me env - | Th_export (p, lc) -> EcEnv.Theory.export p lc env - | Th_instance (tys,i,lc) -> EcEnv.TypeClass.add_instance ~import tys i lc env (*FIXME: import? *) - | Th_baserw (s,lc) -> EcEnv.BaseRw.add ~import s lc env - | Th_addrw (p,ps,lc) -> EcEnv.BaseRw.addto ~import p ps lc env - | Th_auto auto -> EcEnv.Auto.add - ~import ~level:auto.level ?base:auto.base - auto.axioms auto.locality env - | Th_alias (n,p) -> EcEnv.Theory.alias ~import n p env - | Th_reduction r -> EcEnv.Reduction.add ~import r env - | _ -> assert false + | Th_type (s,tyd) -> EcEnv.Ty.bind ~import s tyd env + | Th_operator (s,op) -> EcEnv.Op.bind ~import s op env + | Th_axiom (s, ax) -> EcEnv.Ax.bind ~import s ax env + | Th_modtype (s, ms) -> EcEnv.ModTy.bind ~import s ms env + | Th_module me -> EcEnv.Mod.bind ~import me.tme_expr.me_name me env + | Th_typeclass(s,tc) -> EcEnv.TypeClass.bind ~import s tc env + | Th_export (p, lc) -> EcEnv.Theory.export p lc env + | Th_instance (x, tc) -> EcEnv.TcInstance.bind ~import x tc env + | Th_baserw (s,lc) -> EcEnv.BaseRw.add ~import s lc env + | Th_addrw (p,ps,lc) -> EcEnv.BaseRw.addto ~import p ps lc env + | Th_auto { level; base; axioms = ps; locality = lc } -> + EcEnv.Auto.add ~import ~level ?base ps lc env + | Th_reduction r -> EcEnv.Reduction.add ~import r env + | Th_alias (n, p) -> EcEnv.Theory.alias ~import n p env + | _ -> assert false in - (item, { scenv with + { scenv with sc_env = env; - sc_items = SC_th_item item :: scenv.sc_items}) + sc_items = SC_th_item item :: scenv.sc_items} let add_th ~import (cth : EcEnv.Theory.compiled_theory) scenv = let env = EcEnv.Theory.bind ~import cth scenv.sc_env in @@ -1431,19 +1445,22 @@ let rec generalize_th_item (to_gen : to_gen) (prefix : path) (th_item : theory_i | Th_module me -> generalize_module to_gen prefix me | Th_theory th -> (generalize_ctheory to_gen prefix th, None) | Th_export (p,lc) -> generalize_export to_gen (p,lc) - | Th_instance (ty,i,lc) -> generalize_instance to_gen (ty,i,lc) + | Th_instance (x,tci)-> generalize_instance to_gen (x,tci) + | Th_typeclass (x, tc) -> + if tc.tc_loca = `Local then to_gen, None + else to_gen, Some (Th_typeclass (x, tc)) | Th_baserw (s,lc) -> generalize_baserw to_gen prefix (s,lc) | Th_addrw (p,ps,lc) -> generalize_addrw to_gen (p, ps, lc) | Th_reduction rl -> generalize_reduction to_gen rl | Th_auto hints -> generalize_auto to_gen hints - | Th_alias _ -> (to_gen, None) (* FIXME:ALIAS *) + | Th_alias _ -> (to_gen, None) in let scenv = item |> Option.fold ~none:to_gen.tg_env ~some:(fun item -> let item = { ti_import = th_item.ti_import; ti_item = item; } in - add_item_ item to_gen.tg_env |> snd + add_item_ item to_gen.tg_env ) in @@ -1460,12 +1477,25 @@ and generalize_ctheory if cth.cth_mode = `Abstract && cth.cth_loca = `Local then add_clear genenv (`Th path) else - let scenv = enter_theory name `Global cth.cth_mode genenv.tg_env in - let genenv_tmp = List.fold_left - (fun x -> generalize_th_item x path) - { genenv with tg_env = scenv } cth.cth_items in + let compiled = + let genenv = + let scenv = + enter_theory + name `Global cth.cth_mode + genenv.tg_env + in + { genenv with tg_env = scenv } + in + + let genenv = + List.fold_left (fun genenv item -> + generalize_th_item genenv path item + ) genenv cth.cth_items in + + let _, compiled, _ = exit_theory genenv.tg_env in - let _, compiled, _ = exit_theory genenv_tmp.tg_env in + compiled + in match compiled with | None -> @@ -1496,7 +1526,7 @@ let genenv_of_scenv (scenv : scenv) : to_gen = ; tg_params = [] ; tg_binds = [] ; tg_subst = EcSubst.empty - ; tg_clear = empty_locals } + ; tg_clear = empty_locals } let generalize_lc_items scenv = let togen = @@ -1505,7 +1535,7 @@ let generalize_lc_items scenv = (EcEnv.root scenv.sc_env) (List.rev scenv.sc_items) in togen.tg_env - + (* -----------------------------------------------------------*) let import p scenv = { scenv with sc_env = EcEnv.Theory.import p scenv.sc_env } @@ -1532,23 +1562,28 @@ let check_item scenv item = | Th_axiom (s, ax) -> check_ax scenv prefix s ax | Th_modtype (s, ms) -> check_modtype scenv prefix s ms | Th_module me -> check_module scenv prefix me + | Th_typeclass (s,tc) -> check_tcdecl scenv prefix s tc | Th_export (_, lc) -> assert (lc = `Global || scenv.sc_insec); - | Th_instance (ty,tci,lc) -> check_instance scenv ty tci lc + | Th_instance(x, tci) -> check_instance scenv prefix x tci | Th_baserw (_,lc) -> if (lc = `Local && not scenv.sc_insec) then hierror "local base rewrite can only be declared inside section"; | Th_addrw (_,_,lc) -> if (lc = `Local && not scenv.sc_insec) then hierror "local hint rewrite can only be declared inside section"; - | Th_auto { locality } -> - if (locality = `Local && not scenv.sc_insec) then + | Th_auto { locality = lc; _ } -> + if (lc = `Local && not scenv.sc_insec) then hierror "local hint can only be declared inside section"; | Th_reduction _ -> () + | Th_alias _ -> () | Th_theory _ -> assert false - | Th_alias _ -> () (* FIXME:ALIAS *) let rec add_item ?(override_locality=None) (item : theory_item) (scenv : scenv) = - let item, scenv1 = add_item_ ~override_locality item scenv in + let item = match override_locality, scenv.sc_loca with + | Some lc, _ | None, (`Local as lc) -> set_lc_item lc item + | _ -> item + in + let scenv1 = add_item_ item scenv in begin match item.ti_item with | Th_theory (s,cth) -> if cth.cth_loca = `Local && not scenv.sc_insec then @@ -1577,9 +1612,10 @@ let add_decl_mod id mt scenv = d_mod = [`Declare; `Global]; d_modty = [`Global]; d_tc = [`Global]; + d_tci = [`Global]; } in let from = `Declare, `Module (mpath_abs id []) in - on_mty_mr (mkaenv scenv.sc_env (cb scenv from cd)) mt; + on_mty_mr (cb scenv from cd) mt; { scenv with sc_env = EcEnv.Mod.declare_local id mt scenv.sc_env; sc_items = SC_decl_mod (id, mt) :: scenv.sc_items } diff --git a/src/ecSmt.ml b/src/ecSmt.ml index b0230936ae..4228c0f774 100644 --- a/src/ecSmt.ml +++ b/src/ecSmt.ml @@ -266,14 +266,14 @@ let trans_tv lenv id = oget (Mid.find_opt id lenv.le_tv) (* -------------------------------------------------------------------- *) let lenv_of_tparams ts = - let trans_tv env (id : ty_param) = (* FIXME: TC HOOK *) + let trans_tv env ((id, _) : ty_param) = let tv = WTy.create_tvsymbol (preid id) in { env with le_tv = Mid.add id (WTy.ty_var tv) env.le_tv }, tv in List.map_fold trans_tv empty_lenv ts let lenv_of_tparams_for_hyp genv ts = - let trans_tv env (id : ty_param) = (* FIXME: TC HOOK *) + let trans_tv env ((id, _) : ty_param) = let ts = WTy.create_tysymbol (preid id) [] WTy.NoDef in genv.te_task <- WTask.add_ty_decl genv.te_task ts; { env with le_tv = Mid.add id (WTy.ty_app ts []) env.le_tv }, ts @@ -376,7 +376,7 @@ let rec trans_ty ((genv, lenv) as env) ty = | Tconstr (p, tys) -> let id = trans_pty genv p in - WTy.ty_app id (trans_tys env tys) + WTy.ty_app id (trans_tys env (List.fst tys)) | Tfun (t1, t2) -> WTy.ty_func (trans_ty env t1) (trans_ty env t2) @@ -400,22 +400,22 @@ and trans_tydecl genv (p, tydecl) = let ts, opts, decl = match tydecl.tyd_type with - | Abstract -> + | `Abstract _ -> let ts = WTy.create_tysymbol pid tparams WTy.NoDef in (ts, [], WDecl.create_ty_decl ts) - | Concrete ty -> + | `Concrete ty -> let ty = trans_ty (genv, lenv) ty in let ts = WTy.create_tysymbol pid tparams (WTy.Alias ty) in (ts, [], WDecl.create_ty_decl ts) - | Datatype dt -> + | `Datatype dt -> let ncs = List.length dt.tydt_ctors in let ts = WTy.create_tysymbol pid tparams WTy.NoDef in Hp.add genv.te_ty p ts; - let wdom = tconstr p (List.map tvar tydecl.tyd_params) in + let wdom = tconstr_tc p (etyargs_of_tparams tydecl.tyd_params) in let wdom = trans_ty (genv, lenv) wdom in let for_ctor (c, ctys) = @@ -429,12 +429,12 @@ and trans_tydecl genv (p, tydecl) = (ts, opts, WDecl.create_data_decl [ts, wdtype]) - | Record (_, rc) -> + | `Record (_, rc) -> let ts = WTy.create_tysymbol pid tparams WTy.NoDef in Hp.add genv.te_ty p ts; - let wdom = tconstr p (List.map tvar tydecl.tyd_params) in + let wdom = tconstr_tc p (etyargs_of_tparams tydecl.tyd_params) in let wdom = trans_ty (genv, lenv) wdom in let for_field (fname, fty) = @@ -712,6 +712,7 @@ and trans_app ((genv, lenv) as env : tenv * lenv) (f : form) args = | Fop (p, ts) -> let wop = trans_op genv p in + let ts = List.fst ts in let tys = List.map (trans_ty (genv,lenv)) ts in apply_wop genv wop tys args @@ -764,7 +765,7 @@ and trans_branch (genv, lenv) (p, _dty, tvs) (f, (cname, argsty)) = in let lenv, ws = trans_lvars genv lenv xs in - let wcty = trans_ty (genv, lenv) (tconstr p tvs) in + let wcty = trans_ty (genv, lenv) (tconstr_tc p tvs) in let ws = List.map WTerm.pat_var ws in let ws = WTerm.pat_app csymb ws wcty in let wf = trans_app (genv, lenv) f [] in @@ -1034,9 +1035,9 @@ and create_op ?(body = false) (genv : tenv) p = let lenv, wparams = lenv_of_tparams op.op_tparams in let dom, codom = EcEnv.Ty.signature genv.te_env op.op_ty in let textra = - List.filter (fun tv -> not (Mid.mem tv (EcTypes.Tvar.fv op.op_ty))) op.op_tparams in + List.filter (fun (tv, _) -> not (Mid.mem tv (EcTypes.Tvar.fv op.op_ty))) op.op_tparams in let textra = - List.map (fun tv -> trans_ty (genv,lenv) (tvar tv)) textra in + List.map (fun (tv, _) -> trans_ty (genv,lenv) (tvar tv)) textra in let wdom = trans_tys (genv, lenv) dom in let wcodom = if ER.EqTest.is_bool genv.te_env codom @@ -1188,17 +1189,44 @@ let trans_hyp ((genv, lenv) as env) (x, ty) = | LD_abs_st _ -> env -(* -------------------------------------------------------------------- *) -let lenv_of_hyps genv (hyps : hyps) : lenv = - let lenv = fst (lenv_of_tparams_for_hyp genv hyps.h_tvar) in - snd (List.fold_left trans_hyp (genv, lenv) (List.rev hyps.h_local)) - (* -------------------------------------------------------------------- *) let trans_axiom genv (p, ax) = (* if not ax.ax_nosmt then *) let lenv = fst (lenv_of_tparams ax.ax_tparams) in add_axiom (genv, lenv) (preid_p p) ax.ax_spec +(* For each typeclass constraint on a goal-context type parameter, pull + in the typeclass axioms (and those of all its ancestors) as Why3 + axioms. The axioms are registered globally with [`NoSmt] visibility + so the relevance heuristic skips them; we add them here on a + per-tparam basis so [smt()] (without explicit hints) can still close + goals over abstract TC carriers. *) +let trans_tc_axioms genv (tparams : ty_params) = + let seen = ref EcPath.Sp.empty in + let trans_one tc = + let ancestors = EcTypeClass.ancestors genv.te_env tc in + List.iter (fun anc -> + match EcEnv.TypeClass.by_path_opt anc.tc_name genv.te_env with + | None -> () + | Some tc_decl -> + List.iter (fun (axname, _) -> + let ax_path = + EcPath.pqoname (EcPath.prefix anc.tc_name) axname in + if not (EcPath.Sp.mem ax_path !seen) then begin + seen := EcPath.Sp.add ax_path !seen; + EcEnv.Ax.by_path_opt ax_path genv.te_env + |> Option.iter (fun ax -> trans_axiom genv (ax_path, ax)) + end + ) tc_decl.tc_axs + ) ancestors in + List.iter (fun (_, tcs) -> List.iter trans_one tcs) tparams + +(* -------------------------------------------------------------------- *) +let lenv_of_hyps genv (hyps : hyps) : lenv = + let lenv = fst (lenv_of_tparams_for_hyp genv hyps.h_tvar) in + trans_tc_axioms genv hyps.h_tvar; + snd (List.fold_left trans_hyp (genv, lenv) (List.rev hyps.h_local)) + (* -------------------------------------------------------------------- *) let mk_predb1 f l _ = f (Cast.force_prop (as_seq1 l)) let mk_predb2 f l _ = curry f (t2_map Cast.force_prop (as_seq2 l)) @@ -1623,6 +1651,16 @@ let dump_why3 (env : EcEnv.env) (filename : string) = let init hyps concl = let env = LDecl.toenv hyps in + (* Pre-reduce typeclass operators so the SMT translation sees ordinary + operators only. With concrete instances in scope this collapses + [(+)<:int + addmonoid>] into [Int.(+)] and similar. Polymorphic TC + ops over abstract carriers stay folded; SMT will treat them as + opaque, which is consistent with their hypotheses being SMT-encoded + similarly. We restrict the reduction to TC unfolding (delta_tc) to + avoid over-simplifying the goal in ways that defeat SMT hints. *) + let concl = + let ri = { EcReduction.no_red with delta_tc = true } in + EcReduction.simplify ri hyps concl in let hyps = LDecl.tohyps hyps in let task = create_global_task () in let known = Lazy.force core_theories in diff --git a/src/ecSubst.ml b/src/ecSubst.ml index 107aba6e06..b674da87c8 100644 --- a/src/ecSubst.ml +++ b/src/ecSubst.ml @@ -1,5 +1,6 @@ (* -------------------------------------------------------------------- *) open EcUtils +open EcMaps open EcAst open EcTypes open EcDecl @@ -27,11 +28,18 @@ exception InconsistentSubst type subst = { sb_module : EcPath.mpath Mid.t; sb_path : EcPath.path Mp.t; - sb_tyvar : ty Mid.t; + sb_tyvar : etyarg Mid.t; sb_elocal : expr Mid.t; sb_flocal : EcCoreFol.form Mid.t; sb_fmem : EcIdent.t Mid.t; - sb_tydef : (EcIdent.t list * ty) Mp.t; + (* [sb_tydef p ↦ (params, body, tcs)] mirrors [sb_tyvar] for path-level + type aliases: alongside the body, [tcs] supplies one [tcwitness] + per TC constraint that [p] declared, expressed in terms of [body]. + For non-TC bindings (most callers) [tcs = []]. The witness list + lets [subst_tcw] resolve [`Abs p; offset; lift] to a concrete + witness on [body], the same way the [`Var a] case resolves through + [sb_tyvar]'s [tcs]. *) + sb_tydef : (EcIdent.t list * ty * tcwitness list) Mp.t; sb_def : (EcIdent.t list * [`Op of expr | `Pred of form]) Mp.t; sb_moddef : EcPath.mpath Mp.t; (* Only top-level modules *) } @@ -138,17 +146,17 @@ let has_def (s : subst) (p : EcPath.path) = Mp.mem p s.sb_def (* -------------------------------------------------------------------- *) -let add_tyvar (s : subst) (x : EcIdent.t) (ty : ty) = +let add_tyvar (s : subst) (x : EcIdent.t) (ety : etyarg) = (* FIXME: check name clash *) let merger = function - | None -> Some ty + | None -> Some ety | Some _ -> raise (SubstNameClash (`Ident x)) in { s with sb_tyvar = Mid.change merger x s.sb_tyvar } (* -------------------------------------------------------------------- *) -let add_tyvars (s : subst) (xs : EcIdent.t list) (tys : ty list) = - List.fold_left2 add_tyvar s xs tys +let add_tyvars (s : subst) (xs : (EcIdent.t * etyarg) list) = + List.fold_left (fun s (x, ety) -> add_tyvar s x ety) s xs (* -------------------------------------------------------------------- *) let rec subst_ty (s : subst) (ty : ty) = @@ -157,23 +165,25 @@ let rec subst_ty (s : subst) (ty : ty) = tglob (EcPath.mget_ident (subst_mpath s (EcPath.mident mp))) | Tunivar _ -> - ty (* FIXME *) + ty | Tvar a -> - Mid.find_def ty a s.sb_tyvar + Mid.find_opt a s.sb_tyvar + |> Option.map fst + |> Option.value ~default:ty | Ttuple tys -> ttuple (subst_tys s tys) - | Tconstr (p, tys) -> begin - let tys = subst_tys s tys in + | Tconstr (p, etys) -> begin + let etys = subst_etyargs s etys in match Mp.find_opt p s.sb_tydef with | None -> - tconstr (subst_path s p) tys + tconstr_tc (subst_path s p) etys - | Some (args, body) -> - let s = List.fold_left2 add_tyvar empty args tys in + | Some (args, body, _tcs) -> + let s = List.fold_left2 add_tyvar empty args etys in subst_ty s body end @@ -184,6 +194,59 @@ let rec subst_ty (s : subst) (ty : ty) = and subst_tys (s : subst) (tys : ty list) = List.map (subst_ty s) tys +(* -------------------------------------------------------------------- *) +and subst_etyarg (s : subst) ((ty, tcws) : etyarg) : etyarg = + (subst_ty s ty, subst_tcws s tcws) + +(* -------------------------------------------------------------------- *) +and subst_etyargs (s : subst) (tyargs : etyarg list) : etyarg list = + List.map (subst_etyarg s) tyargs + +(* -------------------------------------------------------------------- *) +and subst_tcw (s : subst) (tcw : tcwitness) = + match tcw with + | TCIUni _ -> + tcw + + | TCIConcrete ({ etyargs; path; _ } as c) -> + let path = subst_path s path in + let etyargs = subst_etyargs s etyargs in + TCIConcrete { c with etyargs; path } + + | TCIAbstract { support = `Var a; offset; lift } -> + let resolved = + Option.bind (Mid.find_opt a s.sb_tyvar) (fun (_, tcs) -> + Option.map (fun tcw -> bump_lift lift tcw) (List.nth_opt tcs offset)) in + Option.value ~default:tcw resolved + + | TCIAbstract ({ support = `Abs p; offset; lift } as tcw) -> + match Mp.find_opt p s.sb_tydef with + | None -> + TCIAbstract { tcw with support = `Abs (subst_path s p) } + + | Some (_, _body, tcs) when offset < List.length tcs -> + (* Mirror of the [`Var a] case: when the binding carries + [tcwitness]es for [p]'s declared TCs, look up the offset-th + one and bump-lift the embedded path. This is what closes the + gap when an instance/clone substitutes an abstract type [p] + that had TC constraints — without it the offset references + constraints that no longer exist on [body], leaving the + witness pointing nowhere. *) + bump_lift lift (subst_tcw s (List.nth tcs offset)) + + | Some (_, body, _) -> + match body.ty_node with + | Tvar a -> + TCIAbstract { support = `Var a; offset; lift } + | Tconstr (p', _) -> + TCIAbstract { support = `Abs p'; offset; lift } + | _ -> + assert false (* FIXME:TC: substitute via concrete instance lookup *) + +(* -------------------------------------------------------------------- *) +and subst_tcws (s : subst) (tcws : tcwitness list) : tcwitness list = + List.map (subst_tcw s) tcws + (* -------------------------------------------------------------------- *) let add_module (s : subst) (x : EcIdent.t) (m : EcPath.mpath) = let merger = function @@ -268,9 +331,9 @@ let add_path (s : subst) ~src ~dst = assert (Mp.find_opt src s.sb_path = None); { s with sb_path = Mp.add src dst s.sb_path } -let add_tydef (s : subst) p (ids, ty) = +let add_tydef (s : subst) p (typ, ty, tcs) = assert (Mp.find_opt p s.sb_tydef = None); - { s with sb_tydef = Mp.add p (ids, ty) s.sb_tydef } + { s with sb_tydef = Mp.add p (typ, ty, tcs) s.sb_tydef } let add_opdef (s : subst) p (ids, f) = assert (Mp.find_opt p s.sb_def = None); @@ -318,51 +381,80 @@ let subst_expr_lpattern (s : subst) (lp : lpattern) = (* -------------------------------------------------------------------- *) let rec subst_expr (s : subst) (e : expr) = + let mk (node : expr_node) = + let ty = subst_ty s e.e_ty in + mk_expr node ty in + match e.e_node with + | Eint _ -> + mk e.e_node + | Elocal id -> begin match Mid.find id s.sb_elocal with | aout -> aout - | exception Not_found -> e_local id (subst_ty s e.e_ty) + | exception Not_found -> mk (Elocal id) end | Evar pv -> - e_var (subst_progvar s pv) (subst_ty s e.e_ty) - - | Eapp ({ e_node = Eop (p, tys) }, args) when has_opdef s p -> - let tys = subst_tys s tys in - let ty = subst_ty s e.e_ty in - let body = oget (get_opdef s p) in - let args = List.map (subst_expr s) args in - subst_eop ty tys args body - - | Eop (p, tys) when has_opdef s p -> - let tys = subst_tys s tys in - let ty = subst_ty s e.e_ty in - let body = oget (get_opdef s p) in - subst_eop ty tys [] body - - | Eop (p, tys) -> - let p = subst_path s p in - let tys = subst_tys s tys in - let ty = subst_ty s e.e_ty in - e_op p tys ty + mk (Evar (subst_progvar s pv)) + + | Eapp ({ e_node = Eop (p, tyargs) }, args) when has_opdef s p -> + let tyargs = subst_etyargs s tyargs in + let ty = subst_ty s e.e_ty in + let body = oget (get_opdef s p) in + let args = List.map (subst_expr s) args in + subst_eop ty tyargs args body + + | Eapp (hd, args) -> + let hd = subst_expr s hd in + let args = List.map (subst_expr s) args in + mk (Eapp (hd, args)) + + | Eop (p, tyargs) when has_opdef s p -> + let tys = subst_etyargs s tyargs in + let ty = subst_ty s e.e_ty in + let body = oget (get_opdef s p) in + subst_eop ty tys [] body + + | Eop (p, tyargs) -> + let p = subst_path s p in + let tyargs = subst_etyargs s tyargs in + mk (Eop (p, tyargs)) + + | Eif (c, e1, e2) -> + let c = subst_expr s c in + let e1 = subst_expr s e1 in + let e2 = subst_expr s e2 in + mk (Eif (c, e1, e2)) + + | Ematch (c, bs, ty) -> + let c = subst_expr s c in + let bs = List.map (subst_expr s) bs in + let ty = subst_ty s ty in + mk (Ematch (c, bs, ty)) + + | Eproj (sube, (i : int)) -> + let sube = subst_expr s sube in + mk (Eproj (sube, i)) + + | Etuple es -> + let es = List.map (subst_expr s) es in + mk (Etuple es) | Elet (lp, e1, e2) -> - let e1 = subst_expr s e1 in - let s, lp = subst_expr_lpattern s lp in - let e2 = subst_expr s e2 in - e_let lp e1 e2 + let e1 = subst_expr s e1 in + let s, lp = subst_expr_lpattern s lp in + let e2 = subst_expr s e2 in + mk (Elet (lp, e1, e2)) - | Equant (q, b, e1) -> - let s, b = fresh_elocals s b in - let e1 = subst_expr s e1 in - e_quantif q b e1 - - | _ -> e_map (subst_ty s) (subst_expr s) e + | Equant (q, b, bd) -> + let s, b = fresh_elocals s b in + let bd = subst_expr s bd in + mk (Equant (q, b, bd)) (* -------------------------------------------------------------------- *) and subst_eop ety tys args (tyids, e) = - let s = add_tyvars empty tyids tys in + let s = add_tyvars empty (List.combine tyids tys) in let (s, args, e) = match e.e_node with @@ -475,59 +567,83 @@ let subst_form_lpattern (s : subst) (lp : lpattern) = (* -------------------------------------------------------------------- *) let rec subst_form (s : subst) (f : form) = + let mk (node : f_node) = + let ty = subst_ty s f.f_ty in + mk_form node ty in + match f.f_node with - | Fquant (q, b, f1) -> - let s, b = fresh_glocals s b in - let e1 = subst_form s f1 in - f_quant q b e1 + | Fint _ -> + mk (f.f_node) + + | Fquant (q, b, bd) -> + let s, b = fresh_glocals s b in + let bd = subst_form s bd in + mk (Fquant (q, b, bd)) | Fmatch (f, bs, ty) -> - let f = subst_form s f in - let bs = List.map (subst_form s) bs in - let ty = subst_ty s ty in - f_match f bs ty + let f = subst_form s f in + let bs = List.map (subst_form s) bs in + let ty = subst_ty s ty in + mk (Fmatch (f, bs, ty)) | Flet (lp, f, body) -> - let f = subst_form s f in - let s, lp = subst_form_lpattern s lp in - let body = subst_form s body in - f_let lp f body + let f = subst_form s f in + let s, lp = subst_form_lpattern s lp in + let body = subst_form s body in + mk (Flet (lp, f, body)) | Flocal x -> begin - match Mid.find x s.sb_flocal with - | aout -> aout - | exception Not_found -> f_local x (subst_ty s f.f_ty) - end + match Mid.find x s.sb_flocal with + | aout -> aout + | exception Not_found -> mk (Flocal x) + end | Fpvar (pv, m) -> - let pv = subst_progvar s pv in - let ty = subst_ty s f.f_ty in - let m = subst_mem s m in - (f_pvar pv ty m).inv + let pv = subst_progvar s pv in + let m = subst_mem s m in + mk (Fpvar (pv, m)) | Fglob (mp, m) -> - let mp = EcPath.mget_ident (subst_mpath s (EcPath.mident mp)) in - let m = subst_mem s m in - (f_glob mp m).inv - - | Fapp ({ f_node = Fop (p, tys) }, args) when has_def s p -> - let tys = subst_tys s tys in - let ty = subst_ty s f.f_ty in - let body = oget (get_def s p) in - let args = List.map (subst_form s) args in - subst_fop ty tys args body - - | Fop (p, tys) when has_def s p -> - let tys = subst_tys s tys in - let ty = subst_ty s f.f_ty in - let body = oget (get_def s p) in - subst_fop ty tys [] body - - | Fop (p, tys) -> - let p = subst_path s p in - let tys = subst_tys s tys in - let ty = subst_ty s f.f_ty in - f_op p tys ty + let mp = EcPath.mget_ident (subst_mpath s (EcPath.mident mp)) in + let m = subst_mem s m in + mk (Fglob (mp, m)) + + | Fapp ({ f_node = Fop (p, tyargs) }, args) when has_def s p -> + let tys = subst_etyargs s tyargs in + let ty = subst_ty s f.f_ty in + let body = oget (get_def s p) in + let args = List.map (subst_form s) args in + subst_fop ty tys args body + + | Fapp (hd, args) -> + let hd = subst_form s hd in + let args = List.map (subst_form s) args in + mk (Fapp (hd, args)) + + | Fop (p, tyargs) when has_def s p -> + let tyargs = subst_etyargs s tyargs in + let ty = subst_ty s f.f_ty in + let body = oget (get_def s p) in + subst_fop ty tyargs [] body + + | Fop (p, tyargs) -> + let p = subst_path s p in + let tyargs = subst_etyargs s tyargs in + mk (Fop (p, tyargs)) + + | Fif (c, f1, f2) -> + let c = subst_form s c in + let f1 = subst_form s f1 in + let f2 = subst_form s f2 in + mk (Fif (c, f1, f2)) + + | Ftuple fs -> + let fs = List.map (subst_form s) fs in + mk (Ftuple fs) + + | Fproj (subf, (i : int)) -> + let subf = subst_form s subf in + mk (Fproj (subf, i)) | FhoareF hf -> let hf_f = subst_xpath s hf.hf_f in @@ -610,12 +726,9 @@ let rec subst_form (s : subst) (f : form) = let pr_event = map_ss_inv1 (subst_form s) pr_event in f_pr pr_mem pr_fun pr_args pr_event - | Fif _ | Fint _ | Ftuple _ | Fproj _ | Fapp _ -> - f_map (subst_ty s) (subst_form s) f - (* -------------------------------------------------------------------- *) and subst_fop fty tys args (tyids, f) = - let s = add_tyvars empty tyids tys in + let s = add_tyvars empty (List.combine tyids tys) in let (s, args, f) = match f.f_node with @@ -830,10 +943,20 @@ let subst_top_module (s : subst) (m : top_module_expr) = tme_loca = m.tme_loca; } (* -------------------------------------------------------------------- *) -let fresh_tparam (s : subst) (x : ty_param) = +let subst_typeclass (s : subst) (tc : typeclass) = + { tc_name = subst_path s tc.tc_name; + tc_args = subst_etyargs s tc.tc_args; } + +(* -------------------------------------------------------------------- *) +let fresh_tparam (s : subst) ((x, tcs) : ty_param) = let newx = EcIdent.fresh x in - let s = add_tyvar s x (tvar newx) in - (s, newx) + let tcs = List.map (subst_typeclass s) tcs in + let tcw = + let mk (offset : int) = + TCIAbstract { support = `Var newx; offset; lift = [] } + in List.mapi (fun i _ -> mk i) tcs in + let s = add_tyvar s x (tvar newx, tcw) in + (s, (newx, tcs)) (* -------------------------------------------------------------------- *) let fresh_tparams (s : subst) (tparams : ty_params) = @@ -848,30 +971,37 @@ let subst_genty (s : subst) (tparams, ty) = (* -------------------------------------------------------------------- *) let subst_tydecl_body (s : subst) (tyd : ty_body) = match tyd with - | Abstract -> - Abstract + | `Abstract tc -> + `Abstract (List.map (subst_typeclass s) tc) - | Concrete ty -> - Concrete (subst_ty s ty) + | `Concrete ty -> + `Concrete (subst_ty s ty) - | Datatype dtype -> + | `Datatype dtype -> let dtype = { tydt_ctors = List.map (snd_map (List.map (subst_ty s))) dtype.tydt_ctors; tydt_schelim = subst_form s dtype.tydt_schelim; tydt_schcase = subst_form s dtype.tydt_schcase; } - in Datatype dtype + in `Datatype dtype - | Record (scheme, fields) -> - Record (subst_form s scheme, List.map (snd_map (subst_ty s)) fields) + | `Record (scheme, fields) -> + `Record (subst_form s scheme, List.map (snd_map (subst_ty s)) fields) (* -------------------------------------------------------------------- *) let subst_tydecl (s : subst) (tyd : tydecl) = let s, tparams = fresh_tparams s tyd.tyd_params in let body = subst_tydecl_body s tyd.tyd_type in + let tyd_subtype = + Option.map + (fun (carrier, pred) -> (subst_ty s carrier, subst_form s pred)) + tyd.tyd_subtype + in { tyd_params = tparams; tyd_type = body; - tyd_loca = tyd.tyd_loca; } + tyd_resolve = tyd.tyd_resolve; + tyd_loca = tyd.tyd_loca; + tyd_subtype; } (* -------------------------------------------------------------------- *) let rec subst_op_kind (s : subst) (kind : operator_kind) = @@ -911,8 +1041,9 @@ and subst_op_body (s : subst) (bd : opbody) = opf_resty = subst_ty s opfix.opf_resty; opf_struct = opfix.opf_struct; opf_branches = subst_branches es opfix.opf_branches; } + + | OP_TC (p, n) -> OP_TC (subst_path s p, n) | OP_Exn tys -> OP_Exn (List.map (subst_ty s) tys) - | OP_TC -> OP_TC and subst_branches (s : subst) = function | OPB_Leaf (locals, e) -> @@ -1007,19 +1138,39 @@ let subst_field (s : subst) cr = f_inv = subst_path s cr.f_inv; f_div = omap (subst_path s) cr.f_div; } -(* -------------------------------------------------------------------- *) -let subst_instance (s : subst) tci = - match tci with - | `Ring cr -> `Ring (subst_ring s cr) - | `Field cr -> `Field (subst_field s cr) - | `General p -> `General (subst_path s p) - (* -------------------------------------------------------------------- *) let subst_tc (s : subst) tc = - let tc_prt = omap (subst_path s) tc.tc_prt in + let s, tc_tparams = fresh_tparams s tc.tc_tparams in + let tc_prts = + List.map (fun (p, ren) -> (subst_typeclass s p, ren)) tc.tc_prts in let tc_ops = List.map (snd_map (subst_ty s)) tc.tc_ops in let tc_axs = List.map (snd_map (subst_form s)) tc.tc_axs in - { tc_prt; tc_ops; tc_axs; tc_loca = tc.tc_loca } + { tc_tparams; tc_prts; tc_ops; tc_axs; tc_loca = tc.tc_loca } + +(* -------------------------------------------------------------------- *) +let subst_tcibody (s : subst) (tci : tcibody) = + match tci with + | `Ring cr -> `Ring (subst_ring s cr) + | `Field cr -> `Field (subst_field s cr) + + | `General (tc, syms) -> + let tc = subst_typeclass s tc in + let syms = + Option.map + (Mstr.map (fun (p, tys) -> (subst_path s p, subst_etyargs s tys))) + syms in + `General (tc, syms) + + +(* -------------------------------------------------------------------- *) +let subst_tcinstance (s : subst) (tci : tcinstance) = + let s, tci_params = fresh_tparams s tci.tci_params in + let tci_type = subst_ty s tci.tci_type in + let tci_instance = subst_tcibody s tci.tci_instance in + let tci_local = tci.tci_local in + let tci_parents = tci.tci_parents in + + { tci_params; tci_type; tci_instance; tci_local; tci_parents; } (* -------------------------------------------------------------------- *) @@ -1052,8 +1203,8 @@ let rec subst_theory_item_r (s : subst) (item : theory_item_r) = | Th_export (p, lc) -> Th_export (subst_path s p, lc) - | Th_instance (ty, tci, lc) -> - Th_instance (subst_genty s ty, subst_instance s tci, lc) + | Th_instance (x, tci) -> + Th_instance (x, subst_tcinstance s tci) | Th_baserw _ -> item @@ -1073,6 +1224,9 @@ let rec subst_theory_item_r (s : subst) (item : theory_item_r) = | Th_alias (name, target) -> Th_alias (name, subst_path s target) + | Th_typeclass (x, tc) -> + Th_typeclass (x, subst_tc s tc) + (* -------------------------------------------------------------------- *) and subst_theory (s : subst) (items : theory) = List.map (subst_theory_item s) items @@ -1109,17 +1263,17 @@ let subst_inv (s : subst) (inv : inv) = | Inv_hs inv -> Inv_hs (subst_hs_inv s inv) (* -------------------------------------------------------------------- *) -let init_tparams (params : (EcIdent.t * ty) list) : subst = - List.fold_left (fun s (x, ty) -> add_tyvar s x ty) empty params +let init_tparams (params : (EcIdent.t * etyarg) list) : subst = + add_tyvars empty params (* -------------------------------------------------------------------- *) -let open_oper op tys = - let s = List.combine op.op_tparams tys in +let open_oper (op : operator) (tys : etyarg list) : ty * operator_kind = + let s = List.combine (List.map fst op.op_tparams) tys in let s = init_tparams s in (subst_ty s op.op_ty, subst_op_kind s op.op_kind) -let open_tydecl tyd tys = - let s = List.combine tyd.tyd_params tys in +let open_tydecl (tyd : tydecl) (tys : etyarg list) : EcDecl.ty_body = + let s = List.combine (List.map fst tyd.tyd_params) tys in let s = init_tparams s in subst_tydecl_body s tyd.tyd_type diff --git a/src/ecSubst.mli b/src/ecSubst.mli index 4f6b8c2123..fdb4b6f59d 100644 --- a/src/ecSubst.mli +++ b/src/ecSubst.mli @@ -25,7 +25,8 @@ val is_empty : subst -> bool (* -------------------------------------------------------------------- *) val add_module : subst -> EcIdent.t -> mpath -> subst val add_path : subst -> src:path -> dst:path -> subst -val add_tydef : subst -> path -> (EcIdent.t list * ty) -> subst +val add_tydef : subst -> path -> (EcIdent.t list * ty * tcwitness list) -> subst +val add_tyvar : subst -> EcIdent.t -> etyarg -> subst val add_opdef : subst -> path -> (EcIdent.t list * expr) -> subst val add_pddef : subst -> path -> (EcIdent.t list * form) -> subst val add_moddef : subst -> src:path -> dst:mpath -> subst (* Only concrete modules *) @@ -39,14 +40,15 @@ val rename_flocal : subst -> EcIdent.t -> EcIdent.t -> ty -> subst val freshen_type : (ty_params * ty) -> (ty_params * ty) (* -------------------------------------------------------------------- *) -val subst_theory : subst -> theory -> theory -val subst_ax : subst -> axiom -> axiom -val subst_op : subst -> operator -> operator -val subst_op_body : subst -> opbody -> opbody -val subst_tydecl : subst -> tydecl -> tydecl -val subst_theory : subst -> theory -> theory -val subst_branches : subst -> opbranches -> opbranches -val subst_exception : subst -> exception_ -> exception_ +val subst_theory : subst -> theory -> theory +val subst_ax : subst -> axiom -> axiom +val subst_op : subst -> operator -> operator +val subst_op_body : subst -> opbody -> opbody +val subst_tydecl : subst -> tydecl -> tydecl +val subst_tc : subst -> tc_decl -> tc_decl +val subst_tcinstance : subst -> tcinstance -> tcinstance +val subst_branches : subst -> opbranches -> opbranches +val subst_exception : subst -> exception_ -> exception_ (* -------------------------------------------------------------------- *) val subst_path : subst -> path -> path @@ -65,23 +67,26 @@ val subst_mod_restr : subst -> mod_restr -> mod_restr val subst_oracle_infos : subst -> oracle_infos -> oracle_infos (* -------------------------------------------------------------------- *) -val subst_gty : subst -> gty -> gty -val subst_genty : subst -> (ty_params * ty) -> (ty_params * ty) -val subst_ty : subst -> ty -> ty -val subst_form : subst -> form -> form -val subst_expr : subst -> expr -> expr -val subst_stmt : subst -> stmt -> stmt - -val subst_progvar : subst -> prog_var -> prog_var -val subst_mem : subst -> EcIdent.t -> EcIdent.t -val subst_flocal : subst -> form -> form -val subst_ss_inv : subst -> ss_inv -> ss_inv -val subst_ts_inv : subst -> ts_inv -> ts_inv -val subst_inv : subst -> inv -> inv +val subst_mem : subst -> EcIdent.t -> EcIdent.t +val subst_flocal : subst -> form -> form +val subst_gty : subst -> gty -> gty +val subst_genty : subst -> (ty_params * ty) -> (ty_params * ty) +val subst_ty : subst -> ty -> ty +val subst_etyarg : subst -> etyarg -> etyarg +val subst_tcw : subst -> tcwitness -> tcwitness +val subst_form : subst -> form -> form +val subst_expr : subst -> expr -> expr +val subst_stmt : subst -> stmt -> stmt +val subst_ss_inv : subst -> ss_inv -> ss_inv +val subst_ts_inv : subst -> ts_inv -> ts_inv +val subst_inv : subst -> inv -> inv (* -------------------------------------------------------------------- *) -val open_oper : operator -> ty list -> ty * operator_kind -val open_tydecl : tydecl -> ty list -> ty_body +val open_oper : operator -> etyarg list -> ty * operator_kind +val open_tydecl : tydecl -> etyarg list -> ty_body + +(* -------------------------------------------------------------------- *) +val fresh_tparams : subst -> ty_params -> subst * ty_params (* -------------------------------------------------------------------- *) val ss_inv_rebind : ss_inv -> memory -> ss_inv diff --git a/src/ecTcCanonical.ml b/src/ecTcCanonical.ml new file mode 100644 index 0000000000..51d0947542 --- /dev/null +++ b/src/ecTcCanonical.ml @@ -0,0 +1,195 @@ +(* -------------------------------------------------------------------- *) +(* Canonical form of [TCIAbstract] witnesses. + + For a [TCIAbstract { support; offset; lift }] there can be multiple + path encodings reaching the same [(target_class, cumulative_renaming)] + when the support's TC class has diamond inheritance. The framework + relies on structural equality of [tcwitness] in many places, so two + semantically-equivalent encodings being structurally distinct breaks + downstream reasoning. + + This module builds, for any [tcs : typeclass list] (the TC bounds of + a carrier), the table of canonical [(offset, lift)] paths reaching + each [(target_class, renaming)]. The "canonical" path is the + FIRST-IN-BFS-ORDER encounter, which gives a deterministic choice + without needing to enumerate all paths. + + Phase A of Stage 2 turns this table into the single source of truth + for path-encoded witnesses. Construction sites consult it; matching + / convertibility sites then need no canonicalisation since structural + compare on canonical encodings is correct. *) + +open EcAst +open EcUtils + +(* -------------------------------------------------------------------- *) +(* Compose a parent-edge renaming [outer] with the cumulative + ancestor-to-child renaming [inner]. *) +let compose_renaming + ~(outer : (EcSymbols.symbol * EcSymbols.symbol) list) + ~(inner : (EcSymbols.symbol * EcSymbols.symbol) list) + : (EcSymbols.symbol * EcSymbols.symbol) list += + let inner_map = EcMaps.Mstr.of_list inner in + let from_outer = + List.map + (fun (gp_name, p_name) -> + let c_name = odfl p_name (EcMaps.Mstr.find_opt p_name inner_map) in + (gp_name, c_name)) + outer in + let outer_p_names = + List.fold_left (fun s (_, p) -> EcMaps.Sstr.add p s) EcMaps.Sstr.empty outer in + let outer_gp_names = + List.fold_left (fun s (gp, _) -> EcMaps.Sstr.add gp s) EcMaps.Sstr.empty outer in + let from_inner = + List.filter_map + (fun (p_name, c_name) -> + if EcMaps.Sstr.mem p_name outer_p_names || EcMaps.Sstr.mem p_name outer_gp_names + then None + else Some (p_name, c_name)) + inner in + from_outer @ from_inner + +(* -------------------------------------------------------------------- *) +(* Renaming equality (set of pairs, order-insensitive). *) +let ren_equal + (r1 : (EcSymbols.symbol * EcSymbols.symbol) list) + (r2 : (EcSymbols.symbol * EcSymbols.symbol) list) + : bool += + List.length r1 = List.length r2 + && List.for_all (fun (a, b) -> + match List.assoc_opt a r2 with + | Some b' -> b = b' + | None -> false) r1 + +(* -------------------------------------------------------------------- *) +(* Walk a lift path from [start], composing renamings. Returns + [Some (target_tc, cumulative_renaming)] iff every index in [lift] + is in range. *) +let walk_path (env : EcEnv.env) (start : typeclass) (lift : int list) + : (typeclass * (EcSymbols.symbol * EcSymbols.symbol) list) option += + let rec aux tc ren = function + | [] -> Some (tc, ren) + | i :: rest -> + let decl = EcEnv.TypeClass.by_path tc.tc_name env in + match List.nth_opt decl.tc_prts i with + | None -> None + | Some (parent, p_ren) -> + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> EcIdent.Mid.add a etyarg s) + EcIdent.Mid.empty decl.tc_tparams tc.tc_args in + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + let ren' = compose_renaming ~outer:p_ren ~inner:ren in + aux parent ren' rest + in aux start [] lift + +(* -------------------------------------------------------------------- *) +(* Build the canonical-paths table: for each [(target_tc_name, ren)] + reachable from [tcs], record the [(offset, lift)] of the FIRST + path encountered in BFS order. Repeat encounters are skipped, so + each [(target, ren)] gets exactly one canonical encoding. *) +type canon_key = EcPath.path * (EcSymbols.symbol * EcSymbols.symbol) list +type canon_path = int * int list +type canon_table = (canon_key * canon_path) list + +let canonical_table + (env : EcEnv.env) + (tcs : typeclass list) + : canon_table += + let recorded = ref [] in + let already ((target_path, ren) : canon_key) = + List.exists + (fun ((p, r), _) -> EcPath.p_equal p target_path && ren_equal r ren) + !recorded in + let record (target_path : EcPath.path) (ren : _) (offset : int) (lift : int list) = + let key = (target_path, ren) in + if not (already key) then + recorded := (key, (offset, lift)) :: !recorded in + let rec bfs frontier = + match frontier with + | [] -> () + | (tc, ren, offset, rev_lift) :: rest -> + record tc.tc_name ren offset (List.rev rev_lift); + let decl = EcEnv.TypeClass.by_path tc.tc_name env in + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> EcIdent.Mid.add a etyarg s) + EcIdent.Mid.empty decl.tc_tparams tc.tc_args in + let next = + List.mapi + (fun i (parent, p_ren) -> + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + let ren' = compose_renaming ~outer:p_ren ~inner:ren in + (parent, ren', offset, i :: rev_lift)) + decl.tc_prts in + bfs (rest @ next) in + let initial = + List.mapi (fun i tc -> (tc, [], i, [])) tcs in + bfs initial; + List.rev !recorded + +(* -------------------------------------------------------------------- *) +(* Canonical [(offset, lift)] reaching [(target_path, target_ren)] from + [tcs], using the BFS-first table. *) +let canonical_path + (env : EcEnv.env) + (tcs : typeclass list) + (target : EcPath.path) + (target_ren : (EcSymbols.symbol * EcSymbols.symbol) list) + : canon_path option += + let table = canonical_table env tcs in + List.find_opt + (fun ((p, r), _) -> EcPath.p_equal p target && ren_equal r target_ren) + table + |> Option.map snd + +(* -------------------------------------------------------------------- *) +(* Look up the TC constraints of an abstract-witness support. *) +let support_tcs (env : EcEnv.env) + (sup : [ `Var of EcIdent.t | `Abs of EcPath.path ]) + : typeclass list option += + match sup with + | `Abs p -> begin + match EcEnv.Ty.by_path_opt p env with + | Some { tyd_type = `Abstract tcs; _ } -> Some tcs + | _ -> None + end + | `Var _ -> + (* [`Var v] supports require the surrounding context's tparam-to-TC + map. Without it (only an [EcEnv.env] is available globally), we + can't canonicalise; leave the witness untouched. *) + None + +(* -------------------------------------------------------------------- *) +(* Canonicalise a single tcwitness via the table. Only [TCIAbstract] is + changed, only when its support's TC list is reachable from [env] and + the path's target/renaming has a recorded canonical encoding. *) +let rec canonicalise_witness (env : EcEnv.env) (tcw : tcwitness) : tcwitness = + match tcw with + | TCIUni _ -> tcw + + | TCIConcrete c -> + let etyargs = + List.map (fun (ty, ws) -> (ty, List.map (canonicalise_witness env) ws)) + c.etyargs in + TCIConcrete { c with etyargs } + + | TCIAbstract { support; offset; lift } -> begin + match support_tcs env support with + | None -> tcw + | Some tcs -> + match walk_path env (List.nth tcs offset) lift with + | None -> tcw + | Some (target, ren) -> + match canonical_path env tcs target.tc_name ren with + | None -> tcw + | Some (o', l') -> + if o' = offset && l' = lift then tcw + else TCIAbstract { support; offset = o'; lift = l' } + end diff --git a/src/ecThCloning.ml b/src/ecThCloning.ml index e55075da2a..e8aca6cbfc 100644 --- a/src/ecThCloning.ml +++ b/src/ecThCloning.ml @@ -72,6 +72,7 @@ type evclone = { evc_ops : (xop_override located) Msym.t; evc_preds : (xpr_override located) Msym.t; evc_abbrevs : (nt_override located) Msym.t; + evc_modexprs : (me_override located) Msym.t; evc_modtypes : (mt_override located) Msym.t; evc_lemmas : evlemma; evc_ths : (evclone * bool) Msym.t; @@ -93,6 +94,7 @@ let evc_empty = evc_ops = Msym.empty; evc_preds = Msym.empty; evc_abbrevs = Msym.empty; + evc_modexprs = Msym.empty; evc_modtypes = Msym.empty; evc_lemmas = evl; evc_ths = Msym.empty; } @@ -523,6 +525,7 @@ end = struct | Th_reduction _ -> (proofs, evc) | Th_auto _ -> (proofs, evc) | Th_alias _ -> (proofs, evc) + | Th_typeclass _ -> (proofs, evc) and doit prefix (proofs, evc) dth = doit_r prefix (proofs, evc) dth.ti_item diff --git a/src/ecThCloning.mli b/src/ecThCloning.mli index 4720f2cbcd..346c47fd31 100644 --- a/src/ecThCloning.mli +++ b/src/ecThCloning.mli @@ -58,6 +58,7 @@ type evclone = { evc_ops : (xop_override located) Msym.t; evc_preds : (xpr_override located) Msym.t; evc_abbrevs : (nt_override located) Msym.t; + evc_modexprs : (me_override located) Msym.t; evc_modtypes : (mt_override located) Msym.t; evc_lemmas : evlemma; evc_ths : (evclone * bool) Msym.t; diff --git a/src/ecTheory.ml b/src/ecTheory.ml index c4bcbfe19c..47e9f49917 100644 --- a/src/ecTheory.ml +++ b/src/ecTheory.ml @@ -1,6 +1,7 @@ (* -------------------------------------------------------------------- *) open EcUtils open EcSymbols +open EcMaps open EcPath open EcAst open EcTypes @@ -26,7 +27,8 @@ and theory_item_r = | Th_module of top_module_expr | Th_theory of (symbol * ctheory) | Th_export of EcPath.path * is_local - | Th_instance of (ty_params * EcTypes.ty) * tcinstance * is_local + | Th_instance of (symbol option * tcinstance) + | Th_typeclass of (symbol * tc_decl) | Th_baserw of symbol * is_local | Th_addrw of EcPath.path * EcPath.path list * is_local | Th_reduction of (EcPath.path * rule_option * rule option) list @@ -44,8 +46,27 @@ and ctheory = { cth_source : thsource option; } -and tcinstance = [ `Ring of ring | `Field of field | `General of path ] -and thmode = [ `Abstract | `Concrete ] +and tcinstance = { + tci_params : ty_params; + tci_type : ty; + tci_instance : tcibody; + tci_local : locality; + (* When this instance was synthesised by [add_generic_instance] as + the projection of a parent class's instance via the subclass + chain, [tci_parents] gives the synthesised parent-instance paths + in the same order as the underlying TC's [tc_prts]. Empty for + manually-declared instances. Used by [resolve_lifted] to walk + the correct ancestor when multiple parent paths exist. *) + tci_parents : EcPath.path list; +} + +and tcibody = [ + | `Ring of ring + | `Field of field + | `General of typeclass * ((path * etyarg list) Mstr.t) option +] + +and thmode = [ `Abstract | `Concrete ] and rule_pattern = | Rule of top_rule_pattern * rule_pattern list @@ -53,7 +74,7 @@ and rule_pattern = | Var of EcIdent.t and top_rule_pattern = - [`Op of (EcPath.path * EcTypes.ty list) | `Tuple | `Proj of int] + [`Op of (EcPath.path * etyarg list) | `Tuple | `Proj of int] and rule = { rl_tyd : EcDecl.ty_params; diff --git a/src/ecTheory.mli b/src/ecTheory.mli index 20c34364b8..2a537771fe 100644 --- a/src/ecTheory.mli +++ b/src/ecTheory.mli @@ -1,5 +1,6 @@ (* -------------------------------------------------------------------- *) open EcSymbols +open EcMaps open EcPath open EcAst open EcTypes @@ -22,10 +23,11 @@ and theory_item_r = | Th_module of top_module_expr | Th_theory of (symbol * ctheory) | Th_export of EcPath.path * is_local - | Th_instance of (ty_params * EcTypes.ty) * tcinstance * is_local + | Th_instance of (symbol option * tcinstance) + | Th_typeclass of (symbol * tc_decl) | Th_baserw of symbol * is_local | Th_addrw of EcPath.path * EcPath.path list * is_local - (* reduction rule does not survive to section so no locality *) + (* reduction rule does not survive section => no locality *) | Th_reduction of (EcPath.path * rule_option * rule option) list | Th_auto of auto_rule | Th_alias of (symbol * path) @@ -41,8 +43,21 @@ and ctheory = { cth_source : thsource option; } -and tcinstance = [ `Ring of ring | `Field of field | `General of EcPath.path ] -and thmode = [ `Abstract | `Concrete ] +and tcinstance = { + tci_params : ty_params; + tci_type : ty; + tci_instance : tcibody; + tci_local : locality; + tci_parents : EcPath.path list; +} + +and tcibody = [ + | `Ring of ring + | `Field of field + | `General of typeclass * ((path * etyarg list) Mstr.t) option +] + +and thmode = [ `Abstract | `Concrete ] and rule_pattern = | Rule of top_rule_pattern * rule_pattern list @@ -50,7 +65,7 @@ and rule_pattern = | Var of EcIdent.t and top_rule_pattern = - [`Op of (EcPath.path * EcTypes.ty list) | `Tuple | `Proj of int] + [`Op of (EcPath.path * etyarg list) | `Tuple | `Proj of int] and rule = { rl_tyd : EcDecl.ty_params; diff --git a/src/ecTheoryReplay.ml b/src/ecTheoryReplay.ml index de60a2a630..033b5e0d09 100644 --- a/src/ecTheoryReplay.ml +++ b/src/ecTheoryReplay.ml @@ -1,5 +1,6 @@ (* ------------------------------------------------------------------ *) open EcSymbols +open EcMaps open EcUtils open EcLocation open EcParsetree @@ -50,283 +51,209 @@ let keep_of_mode (mode : clmode) = (* -------------------------------------------------------------------- *) exception Incompatible of incompatible +let tparams_compatible (rtyvars : ty_params) (ntyvars : ty_params) = + let rlen = List.length rtyvars and nlen = List.length ntyvars in + if rlen <> nlen then + raise (Incompatible (NotSameNumberOfTyParam (rlen, nlen))) + +let ty_compatible env ue (rtyvars, rty) (ntyvars, nty) = + tparams_compatible rtyvars ntyvars; + let subst = + let etyargs = etyargs_of_tparams ntyvars in + CS.Tvar.init (List.combine (List.fst rtyvars) etyargs) in + let rty = CS.Tvar.subst subst rty in + try EcUnify.unify env ue rty nty + with EcUnify.UnificationFailure _ -> + raise (Incompatible (DifferentType (rty, nty))) + (* -------------------------------------------------------------------- *) let error_body exn b = if not b then raise exn (* -------------------------------------------------------------------- *) -let get_open_tydecl (env : EcEnv.env) (p : EcPath.path) (tys : ty list) = - let tydecl = EcEnv.Ty.by_path p env in - EcSubst.open_tydecl tydecl tys +let ri_compatible = + { EcReduction.full_red with delta_p = (fun _-> `Force); user = false } (* -------------------------------------------------------------------- *) -exception CoreIncompatible +let constr_compatible exn env cs1 cs2 = + error_body exn (List.length cs1 = List.length cs2); + let doit (s1,tys1) (s2,tys2) = + error_body exn (EcSymbols.sym_equal s1 s2); + error_body exn (List.length tys1 = List.length tys2); + List.iter2 (fun ty1 ty2 -> error_body exn (EcReduction.EqTest.for_type env ty1 ty2)) tys1 tys2 in + List.iter2 doit cs1 cs2 + +let datatype_compatible exn hyps ty1 ty2 = + let env = EcEnv.LDecl.toenv hyps in + constr_compatible exn env ty1.tydt_ctors ty2.tydt_ctors + +let record_compatible exn hyps f1 pr1 f2 pr2 = + error_body exn (EcReduction.is_conv hyps f1 f2); + error_body exn (List.length pr1 = List.length pr2); + let env = EcEnv.LDecl.toenv hyps in + let doit (s1,ty1) (s2,ty2) = + error_body exn (EcSymbols.sym_equal s1 s2); + error_body exn (EcReduction.EqTest.for_type env ty1 ty2) in + List.iter2 doit pr1 pr2 + +let get_open_tydecl hyps p tys = + let tydecl = EcEnv.Ty.by_path p (EcEnv.LDecl.toenv hyps) in + EcSubst.open_tydecl tydecl tys -(* -------------------------------------------------------------------- *) -exception NoException +let rec tybody_compatible exn hyps ty_body1 ty_body2 = + match ty_body1, ty_body2 with + | `Abstract _, `Abstract _ -> () (* FIXME Sp.t *) + | `Concrete ty1 , `Concrete ty2 -> error_body exn (EcReduction.EqTest.for_type (EcEnv.LDecl.toenv hyps) ty1 ty2) + | `Datatype ty1 , `Datatype ty2 -> datatype_compatible exn hyps ty1 ty2 + | `Record (f1,pr1), `Record(f2,pr2) -> record_compatible exn hyps f1 pr1 f2 pr2 + | _, `Concrete {ty_node = Tconstr(p, tys) } -> + let ty_body2 = get_open_tydecl hyps p tys in + tybody_compatible exn hyps ty_body1 ty_body2 + | `Concrete{ty_node = Tconstr(p, tys) }, _ -> + let ty_body1 = get_open_tydecl hyps p tys in + tybody_compatible exn hyps ty_body1 ty_body2 + | _, _ -> raise exn (* FIXME should we do more for concrete version other *) + +let tydecl_compatible env tyd1 tyd2 = + let params = tyd1.tyd_params in + tparams_compatible params tyd2.tyd_params; + let tparams = etyargs_of_tparams params in + let ty_body1 = tyd1.tyd_type in + let ty_body2 = EcSubst.open_tydecl tyd2 tparams in + let exn = Incompatible (TyBody(*tyd1,tyd2*)) in + let hyps = EcEnv.LDecl.init env params in + match ty_body1, ty_body2 with + | `Abstract _, _ -> () (* FIXME Sp.t *) + | _, _ -> tybody_compatible exn hyps ty_body1 ty_body2 (* -------------------------------------------------------------------- *) -let get_open_oper (env : EcEnv.env) (p : EcPath.path) (tys : ty list) = +let expr_compatible exn env s e1 e2 = + let m = EcIdent.create "&hr" in + let f1 = EcFol.form_of_expr ~m e1 in + let f2 = EcSubst.subst_form s (EcFol.form_of_expr ~m e2) in + error_body exn (EcReduction.is_conv ~ri:ri_compatible (EcEnv.LDecl.init env []) f1 f2) + +let get_open_oper exn env p tys = let oper = EcEnv.Op.by_path p env in let _, okind = EcSubst.open_oper oper tys in match okind with | OB_oper (Some ob) -> ob - | _ -> raise CoreIncompatible - -(* -------------------------------------------------------------------- *) -let get_open_pred (env : EcEnv.env) (p : EcPath.path) (tys : ty list) = + | _ -> raise exn + +let rec oper_compatible exn env ob1 ob2 = + (* FIXME: duplicated code *) + match ob1, ob2 with + | OP_Plain f1, OP_Plain f2 -> + error_body exn (EcReduction.is_conv ~ri:ri_compatible (EcEnv.LDecl.init env []) f1 f2) + | OP_Plain {f_node = Fop(p,tys)}, _ -> + let ob1 = get_open_oper exn env p tys in + oper_compatible exn env ob1 ob2 + | _, OP_Plain {f_node = Fop(p,tys)} -> + let ob2 = get_open_oper exn env p tys in + oper_compatible exn env ob1 ob2 + | OP_Constr(p1,i1), OP_Constr(p2,i2) -> + error_body exn (EcPath.p_equal p1 p2 && i1 = i2) + | OP_Record p1, OP_Record p2 -> + error_body exn (EcPath.p_equal p1 p2) + | OP_Proj(p1,i11,i12), OP_Proj(p2,i21,i22) -> + error_body exn (EcPath.p_equal p1 p2 && i11 = i21 && i12 = i22) + | OP_Fix f1, OP_Fix f2 -> + opfix_compatible exn env f1 f2 + | OP_TC (p1, n1), OP_TC (p2, n2) -> + error_body exn (EcPath.p_equal p1 p2 && n1 = n2) + | _, _ -> raise exn + +and opfix_compatible exn env f1 f2 = + let s = params_compatible exn env EcSubst.empty f1.opf_args f2.opf_args in + error_body exn (EcReduction.EqTest.for_type env f1.opf_resty f2.opf_resty); + error_body exn (f1.opf_struct = f2.opf_struct); + opbranches_compatible exn env s f1.opf_branches f2.opf_branches + +and params_compatible exn env s p1 p2 = + error_body exn (List.length p1 = List.length p2); + let doit s (id1,ty1) (id2,ty2) = + error_body exn (EcReduction.EqTest.for_type env ty1 ty2); + EcSubst.add_flocal s id2 (EcFol.f_local id1 ty1) in + List.fold_left2 doit s p1 p2 + +and opbranches_compatible exn env s ob1 ob2 = + match ob1, ob2 with + | OPB_Leaf(d1,e1), OPB_Leaf(d2,e2) -> + error_body exn (List.length d1 = List.length d2); + let s = + List.fold_left2 (params_compatible exn env) s d1 d2 in + expr_compatible exn env s e1 e2 + + | OPB_Branch obs1, OPB_Branch obs2 -> + error_body exn (Parray.length obs1 = Parray.length obs2); + Parray.iter2 (opbranch_compatible exn env s) obs1 obs2 + | _, _ -> raise exn + +and opbranch_compatible exn env s ob1 ob2 = + error_body exn (EcPath.p_equal (fst ob1.opb_ctor) (fst ob2.opb_ctor)); + error_body exn (snd ob1.opb_ctor = snd ob2.opb_ctor); + opbranches_compatible exn env s ob1.opb_sub ob2.opb_sub + +let get_open_pred exn env p tys = let oper = EcEnv.Op.by_path p env in let _, okind = EcSubst.open_oper oper tys in match okind with | OB_pred (Some pb) -> pb - | _ -> raise CoreIncompatible - -(* -------------------------------------------------------------------- *) -module Compatible : sig - type 'a comparator = EcEnv.env -> 'a -> 'a -> unit - - val for_ty : - EcEnv.env - -> EcUnify.unienv - -> EcIdent.ident list * ty - -> EcIdent.ident list * ty - -> unit - - val for_tydecl : tydecl comparator - val for_operator : operator comparator -end = struct - open EcEnv.LDecl - - type 'a comparator = EcEnv.env -> 'a -> 'a -> unit - - let ri_compatible = - { EcReduction.full_red with delta_p = (fun _-> `Force); user = false } - - let check (b : bool) = - if not b then raise CoreIncompatible - - let for_tparams rtyvars ntyvars = - let rlen = List.length rtyvars - and nlen = List.length ntyvars in - - if rlen <> nlen then - raise (Incompatible (NotSameNumberOfTyParam (rlen, nlen))) - - let for_params - (hyps : hyps) - (s : EcSubst.subst) - (p1 : (EcIdent.ident * ty) list) - (p2 : (EcIdent.ident * ty) list) - = - check (List.compare_lengths p1 p2 = 0); - - let do_param s (id1, ty1) (id2, ty2) = - check (EcReduction.EqTest.for_type (toenv hyps) ty1 ty2); - EcSubst.add_flocal s id2 (EcFol.f_local id1 ty1) - in List.fold_left2 do_param s p1 p2 - - let for_ty (env : EcEnv.env) (ue : EcUnify.unienv) (rtyvars, rty) (ntyvars, nty) = - for_tparams rtyvars ntyvars; - - let subst = CS.Tvar.init rtyvars (List.map tvar ntyvars) in - let rty = CS.Tvar.subst subst rty in - - try EcUnify.unify env ue rty nty - with EcUnify.UnificationFailure _ -> - raise (Incompatible - (DifferentType (rty, nty))) - - let for_expr (hyps : hyps) (s : EcSubst.subst) (e1 : expr) (e2 : expr) = - let f1 = EcFol.form_of_expr e1 in - let f2 = EcSubst.subst_form s (EcFol.form_of_expr e2) in - check (EcReduction.is_conv ~ri:ri_compatible hyps f1 f2) - - let for_datatype (hyps : hyps) (ty1 : ty_dtype) (ty2 : ty_dtype) = - let for_constr (cs1 : ty_dtype_ctor list) (cs2 : ty_dtype_ctor list) = - check (List.compare_lengths cs1 cs2 = 0); - - let for_ctor1 (s1,tys1) (s2,tys2) = - check (EcSymbols.sym_equal s1 s2); - check (List.compare_lengths tys1 tys2 = 0); - List.iter2 (fun ty1 ty2 -> - check (EcReduction.EqTest.for_type (toenv hyps) ty1 ty2) - ) tys1 tys2 - in List.iter2 for_ctor1 cs1 cs2 - in for_constr ty1.tydt_ctors ty2.tydt_ctors - - let for_record (hyps : hyps) ((f1, pr1) : ty_record) ((f2, pr2) : ty_record) = - check (EcReduction.is_conv hyps f1 f2); - - let for_field (s1, ty1) (s2, ty2) = - check (EcSymbols.sym_equal s1 s2); - check (EcReduction.EqTest.for_type (toenv hyps) ty1 ty2) - in List.iter2 for_field pr1 pr2 - - let rec tybody (hyps : EcEnv.LDecl.hyps) (ty_body1 : ty_body) (ty_body2 : ty_body) = - match ty_body1, ty_body2 with - | Abstract , Abstract -> () - | Concrete ty1 , Concrete ty2 -> check (EcReduction.EqTest.for_type (toenv hyps) ty1 ty2) - | Datatype ty1 , Datatype ty2 -> for_datatype hyps ty1 ty2 - | Record rec1, Record rec2 -> for_record hyps rec1 rec2 - - | _, Concrete { ty_node = Tconstr (p, tys) } -> - let ty_body2 = get_open_tydecl (toenv hyps) p tys in - tybody hyps ty_body1 ty_body2 - - | Concrete{ ty_node = Tconstr (p, tys) }, _ -> - let ty_body1 = get_open_tydecl (toenv hyps) p tys in - tybody hyps ty_body1 ty_body2 - - | _, _ -> raise CoreIncompatible - - let for_tydecl (env : EcEnv.env) (tyd1 : tydecl) (tyd2 : tydecl) = - try - let params = tyd1.tyd_params in - - for_tparams params tyd2.tyd_params; - - let tparams = List.map tvar params in - let ty_body1 = tyd1.tyd_type in - let ty_body2 = EcSubst.open_tydecl tyd2 tparams in - - let hyps = EcEnv.LDecl.init env params in - - match ty_body1, ty_body2 with - | Abstract, _ -> () - - | _, _ -> tybody hyps ty_body1 ty_body2 - - with CoreIncompatible -> raise (Incompatible TyBody) - - - let for_opfix (hyps : hyps) (f1 : opfix) (f2 : opfix) = - let rec for_opbranch (s : EcSubst.subst) (ob1 : opbranch) (ob2 : opbranch) = - check (EcPath.p_equal (fst ob1.opb_ctor) (fst ob2.opb_ctor)); - check (snd ob1.opb_ctor = snd ob2.opb_ctor); - for_opbranches hyps s ob1.opb_sub ob2.opb_sub - - and for_opbranches (hyps : hyps) (s : EcSubst.subst) (ob1 : opbranches) (ob2 : opbranches) = - match ob1, ob2 with - | OPB_Leaf (d1, e1), OPB_Leaf (d2, e2) -> - check (List.compare_lengths d1 d2 = 0); - let s = List.fold_left2 (for_params hyps) s d1 d2 in - for_expr hyps s e1 e2 - - | OPB_Branch obs1, OPB_Branch obs2 -> - check (Parray.length obs1 = Parray.length obs2); - Parray.iter2 (for_opbranch s) obs1 obs2 - - | _, _ -> raise CoreIncompatible - in - - check (EcReduction.EqTest.for_type (toenv hyps) f1.opf_resty f2.opf_resty); - check (f1.opf_struct = f2.opf_struct); - - let s = for_params hyps EcSubst.empty f1.opf_args f2.opf_args in - let s = EcSubst.add_path ~src:f2.opf_recp ~dst:f1.opf_recp s - - in for_opbranches hyps s f1.opf_branches f2.opf_branches - - let for_ind (hyps : hyps) (pi1 : prind) (pi2 : prind) = - let for_prctor (s : EcSubst.subst) (prc1 : prctor) (prc2 : prctor) = - check (EcSymbols.sym_equal prc1.prc_ctor prc2.prc_ctor); - let (env, s) = - EcReduction.check_bindings - CoreIncompatible (toenv hyps) s prc1.prc_bds prc2.prc_bds in - let hyps = EcEnv.LDecl.init env [] in - check (List.compare_lengths prc1.prc_spec prc2.prc_spec = 0); - let for_spec (f1 : form) (f2 : form) = - check (EcReduction.is_conv hyps f1 (EcSubst.subst_form s f2)) in - List.iter2 for_spec prc1.prc_spec prc2.prc_spec - in - - let s = for_params hyps EcSubst.empty pi1.pri_args pi2.pri_args in - check (List.compare_lengths pi1.pri_ctors pi2.pri_ctors = 0); - List.iter2 (for_prctor s) pi1.pri_ctors pi2.pri_ctors - - let rec for_oper (hyps : hyps) (ob1 : opbody) (ob2 : opbody) = - match ob1, ob2 with - | OP_Plain f1, OP_Plain f2 -> - check (EcReduction.is_conv ~ri:ri_compatible hyps f1 f2) - - | OP_Plain { f_node = Fop (p, tys) }, _ -> - let ob1 = get_open_oper (toenv hyps) p tys in - for_oper hyps ob1 ob2 - - | _, OP_Plain { f_node = Fop (p, tys) } -> - let ob2 = get_open_oper (toenv hyps) p tys in - for_oper hyps ob1 ob2 - - | OP_Constr (p1, i1), OP_Constr (p2, i2) -> - check (EcPath.p_equal p1 p2 && i1 = i2) - - | OP_Record p1, OP_Record p2 -> - check (EcPath.p_equal p1 p2) - - | OP_Proj (p1, i11, i12), OP_Proj (p2, i21, i22) -> - check (EcPath.p_equal p1 p2 && i11 = i21 && i12 = i22) - - | OP_Fix f1, OP_Fix f2 -> - for_opfix hyps f1 f2 - - | OP_TC, OP_TC -> () - - | OP_Exn _, OP_Exn _ -> raise NoException - (* Replacing exception during cloning is not allowed *) - - | _, _ -> raise CoreIncompatible - - let rec for_pred (hyps : EcEnv.LDecl.hyps) (pb1 : prbody) (pb2 : prbody) = - match pb1, pb2 with - | PR_Plain f1, PR_Plain f2 -> - check (EcReduction.is_conv hyps f1 f2) - - | PR_Plain { f_node = Fop (p, tys) }, _ -> - let pb1 = get_open_pred (toenv hyps) p tys in - for_pred hyps pb1 pb2 - - | _, PR_Plain { f_node = Fop (p, tys) } -> - let pb2 = get_open_pred (toenv hyps) p tys in - for_pred hyps pb1 pb2 - - | PR_Ind pr1, PR_Ind pr2 -> - for_ind hyps pr1 pr2 - - | _, _ -> raise CoreIncompatible - - let for_nott (hyps : hyps) (nb1 : notation) (nb2 : notation) = - let s = for_params hyps EcSubst.empty nb1.ont_args nb2.ont_args in - (* We do not check ont_resty because it is redundant *) - for_expr hyps s nb1.ont_body nb2.ont_body - - let for_operator (env : EcEnv.env) (oper1 : operator) (oper2 : operator) = - let params = oper1.op_tparams in - - for_tparams oper1.op_tparams oper2.op_tparams; - - let oty1, okind1 = oper1.op_ty, oper1.op_kind in - let tparams = List.map tvar params in - let oty2, okind2 = EcSubst.open_oper oper2 tparams in - - if not (EcReduction.EqTest.for_type env oty1 oty2) then - raise (Incompatible (DifferentType(oty1, oty2))); - - let hyps = EcEnv.LDecl.init env params in - - try - match okind1, okind2 with - | OB_oper None , OB_oper _ -> () - | OB_oper (Some ob1), OB_oper (Some ob2) -> for_oper hyps ob2 ob1 - | OB_pred None , OB_pred _ -> () - | OB_pred (Some pb1), OB_pred (Some pb2) -> for_pred hyps pb2 pb1 - | OB_nott nb1 , OB_nott nb2 -> for_nott hyps nb2 nb1 - | _ , _ -> raise (Incompatible OpBody) - - with Failure _ -> raise (Incompatible OpBody) -end + | _ -> raise exn + +let rec pred_compatible exn env pb1 pb2 = + match pb1, pb2 with + | PR_Plain f1, PR_Plain f2 -> error_body exn (EcReduction.is_conv (EcEnv.LDecl.init env []) f1 f2) + | PR_Plain {f_node = Fop(p,tys)}, _ -> + let pb1 = get_open_pred exn env p tys in + pred_compatible exn env pb1 pb2 + | _, PR_Plain {f_node = Fop(p,tys)} -> + let pb2 = get_open_pred exn env p tys in + pred_compatible exn env pb1 pb2 + | PR_Ind pr1, PR_Ind pr2 -> + ind_compatible exn env pr1 pr2 + | _, _ -> raise exn + +and ind_compatible exn env pi1 pi2 = + let s = params_compatible exn env EcSubst.empty pi1.pri_args pi2.pri_args in + error_body exn (List.length pi1.pri_ctors = List.length pi2.pri_ctors); + List.iter2 (prctor_compatible exn env s) pi1.pri_ctors pi2.pri_ctors + +and prctor_compatible exn env s prc1 prc2 = + error_body exn (EcSymbols.sym_equal prc1.prc_ctor prc2.prc_ctor); + let env, s = EcReduction.check_bindings exn env s prc1.prc_bds prc2.prc_bds in + error_body exn (List.length prc1.prc_spec = List.length prc2.prc_spec); + let doit f1 f2 = + error_body exn (EcReduction.is_conv (EcEnv.LDecl.init env []) f1 (EcSubst.subst_form s f2)) in + List.iter2 doit prc1.prc_spec prc2.prc_spec + +let nott_compatible exn env nb1 nb2 = + let s = params_compatible exn env EcSubst.empty nb1.ont_args nb2.ont_args in + (* We do not check ont_resty because it is redundant *) + expr_compatible exn env s nb1.ont_body nb2.ont_body + +let operator_compatible env oper1 oper2 = + let open EcDecl in + let params = oper1.op_tparams in + tparams_compatible oper1.op_tparams oper2.op_tparams; + let oty1, okind1 = oper1.op_ty, oper1.op_kind in + let tparams = etyargs_of_tparams params in + let oty2, okind2 = EcSubst.open_oper oper2 tparams in + if not (EcReduction.EqTest.for_type env oty1 oty2) then + raise (Incompatible (DifferentType(oty1, oty2))); + let hyps = EcEnv.LDecl.init env params in + let env = EcEnv.LDecl.toenv hyps in + let exn = Incompatible (OpBody(*oper1,oper2*)) in + match okind1, okind2 with + | OB_oper None , OB_oper _ -> () + | OB_oper (Some ob1), OB_oper (Some ob2) -> oper_compatible exn env ob2 ob1 + | OB_pred None , OB_pred _ -> () + | OB_pred (Some pb1), OB_pred (Some pb2) -> pred_compatible exn env pb2 pb1 + | OB_nott nb1 , OB_nott nb2 -> nott_compatible exn env nb2 nb1 + | _ , _ -> raise exn (* -------------------------------------------------------------------- *) let check_evtags ?(tags : evtags option) (src : symbol list) = - let exception Reject in + let module E = struct exception Reject end in let explicit = "explicit" in @@ -334,7 +261,7 @@ let check_evtags ?(tags : evtags option) (src : symbol list) = match tags with | None -> if List.mem explicit src then - raise Reject; + raise E.Reject; true | Some tags -> @@ -345,13 +272,13 @@ let check_evtags ?(tags : evtags option) (src : symbol list) = List.map (fun src -> let do1 status (mode, dst) = match mode with - | `Exclude -> if sym_equal src dst then raise Reject; status + | `Exclude -> if sym_equal src dst then raise E.Reject; status | `Include -> status || (sym_equal src dst) in List.fold_left do1 dfl tags) src in List.mem true stt - with Reject -> false + with E.Reject -> false (* -------------------------------------------------------------------- *) let xpath ove x = @@ -364,16 +291,15 @@ let xnpath ove x = (EcPath.fromqsymbol (snd ove.ovre_prefix, x)) (* -------------------------------------------------------------------- *) -let string_of_renaming_kind (rkind : theory_renaming_kind) = - match rkind with +let string_of_renaming_kind = function | `Lemma -> "lemma" | `Op -> "operator" | `Pred -> "predicate" | `Type -> "type" + | `Exn -> "exception" | `Module -> "module" | `ModType -> "module type" | `Theory -> "theory" - | `Exn -> "exception" (* -------------------------------------------------------------------- *) let rename ove subst (kind, name) = @@ -387,9 +313,9 @@ let rename ove subst (kind, name) = let nameok = match kind with - | `Lemma | `Type -> + | `Lemma | `Type | `Exn -> EcIo.is_sym_ident newname - | `Op | `Pred | `Exn -> + | `Op | `Pred -> EcIo.is_op_ident newname | `Module | `ModType | `Theory -> EcIo.is_mod_ident newname @@ -423,38 +349,43 @@ let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd let newtyd, body = match tydov with | `BySyntax (nargs, ntyd) -> - let nargs = List.map - (fun x -> (EcIdent.create (unloc x))) - nargs in + let nargs = List.map2 + (fun (_, tc) x -> (EcIdent.create (unloc x), tc)) + otyd.tyd_params nargs in let ue = EcUnify.UniEnv.create (Some nargs) in let ntyd = EcTyping.transty EcTyping.tp_tydecl env ue ntyd in let decl = { tyd_params = nargs; - tyd_type = Concrete ntyd; - tyd_loca = otyd.tyd_loca; } + tyd_type = `Concrete ntyd; + tyd_resolve = otyd.tyd_resolve && (mode = `Alias); + tyd_loca = otyd.tyd_loca; + tyd_subtype = None; } in (decl, ntyd) | `ByPath p -> begin match EcEnv.Ty.by_path_opt p env with | Some reftyd -> - let tyargs = List.map tvar reftyd.tyd_params in + let tyargs = List.map (fun (x, _) -> tvar x) reftyd.tyd_params in let body = tconstr p tyargs in - let decl = { reftyd with tyd_type = Concrete body; } in + let decl = + { reftyd with + tyd_type = `Concrete body; + tyd_resolve = otyd.tyd_resolve && (mode = `Alias); } in (decl, body) | _ -> assert false end - | `Direct ty -> begin - assert (List.is_empty otyd.tyd_params); - let decl = - { tyd_params = []; - tyd_type = Concrete ty; - tyd_loca = otyd.tyd_loca; } - - in (decl, ty) - end + | `Direct ty -> + assert (List.is_empty otyd.tyd_params); + let decl = + { tyd_params = []; + tyd_type = `Concrete ty; + tyd_resolve = otyd.tyd_resolve && (mode = `Alias); + tyd_loca = otyd.tyd_loca; + tyd_subtype = None; } + in (decl, ty) in let subst, x = @@ -464,30 +395,43 @@ let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd | `Inline _ -> let subst = + (* When [otyd] is [`Abstract tcs] (the cloned source was + [type t <: tc1 <: tc2 …]), we need one [tcwitness] per + TC entry, looked up in the instance database for + [body]. Without these, [`Abs t_path; offset; lift] + witnesses inside cloned axioms would rewrite to + [`Abs body; offset; lift] referencing TC slots [body] + doesn't have. [witnesses_for_body] queries each + via [EcTypeClass.infer]; for non-TC clones (the + common stdlib case) [tcs = []] and the result is just + []. *) + let bodytcs = + match otyd.tyd_type with + | `Abstract tcs -> + EcTypeClass.witnesses_for_body env body tcs + | _ -> [] in EcSubst.add_tydef - subst (xpath ove x) (newtyd.tyd_params, body) in + subst (xpath ove x) + (List.map fst newtyd.tyd_params, body, bodytcs) in let subst = (* FIXME: HACK *) match otyd.tyd_type, body.ty_node with - | Datatype { tydt_ctors = octors }, Tconstr (np, _) -> begin + | `Datatype { tydt_ctors = octors }, Tconstr (np, _) -> begin match (EcEnv.Ty.by_path np env).tyd_type with - | Datatype { tydt_ctors = _ } -> - let newtparams = newtyd.tyd_params in - let newtparams_ty = List.map tvar newtparams in - let newdtype = tconstr np newtparams_ty in - let tysubst = CS.Tvar.init otyd.tyd_params newtparams_ty in + | `Datatype { tydt_ctors = _ } -> + let newtparams = etyargs_of_tparams newtyd.tyd_params in + let newdtype = tconstr_tc np newtparams in + let tysubst = + CS.Tvar.init (List.combine (List.fst otyd.tyd_params) newtparams) in List.fold_left (fun subst (name, tyargs) -> - let np = EcPath.pqoname (EcPath.prefix np) name in - let newtyargs = - List.map - (CS.Tvar.subst tysubst -| EcSubst.subst_ty subst) - tyargs in - EcSubst.add_opdef subst - (xpath ove name) - (newtparams, e_op np newtparams_ty (toarrow newtyargs newdtype)) - ) subst octors + let np = EcPath.pqoname (EcPath.prefix np) name in + let newtyargs = List.map (CS.Tvar.subst tysubst) tyargs in + EcSubst.add_opdef subst + (xpath ove name) + (List.fst newtyd.tyd_params, e_op_tc np newtparams (toarrow newtyargs newdtype))) + subst octors | _ -> subst end | _, _ -> subst @@ -497,7 +441,7 @@ let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd let refotyd = EcSubst.subst_tydecl subst otyd in begin - try Compatible.for_tydecl env refotyd newtyd + try tydecl_compatible env refotyd newtyd with Incompatible err -> clone_error env (CE_TyIncompatible ((snd ove.ovre_prefix, x), err)) end; @@ -523,14 +467,9 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = let scenv = ove.ovre_hooks.henv scope in let env = EcSection.env scenv in - let rk = - match oopd.op_kind with - | OB_oper (Some (OP_Exn _)) -> `Exn - | _ -> `Op in - match Msym.find_opt x ove.ovre_ovrd.evc_ops with | None -> - let (subst, x) = rename ove subst (rk, x) in + let (subst, x) = rename ove subst (`Op, x) in let oopd = EcSubst.subst_op subst oopd in (subst, ops, proofs, ove.ovre_hooks.hadd_item scope ~import (Th_operator (x, oopd))) @@ -542,7 +481,7 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = let newop, body = match opov with | `BySyntax opov -> - let tp = opov.opov_tyvars in + let tp = opov.opov_tyvars |> omap (List.map (fun tv -> (tv, []))) in let ue = EcTyping.transtyvars env (loc, tp) in let tp = EcTyping.tp_relax in let (ty, body) = @@ -553,16 +492,19 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = (lam.f_ty, lam) in begin - try Compatible.for_ty env ue + try ty_compatible env ue (reftyvars, refty) (EcUnify.UniEnv.tparams ue, ty) with Incompatible err -> clone_error env (CE_OpIncompatible ((snd ove.ovre_prefix, x), err)) end; - if not (EcUnify.UniEnv.closed ue) then - ove.ovre_hooks.herr - ~loc "this operator body contains free type variables"; + Option.iter (fun infos -> + ove.ovre_hooks.herr ~loc + (Format.asprintf + "this operator body contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos) + ) (EcUnify.UniEnv.xclosed ue); let sty = CS.Tuni.subst (EcUnify.UniEnv.close ue) in let body = EcFol.Fsubst.f_subst sty body in @@ -577,7 +519,7 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = | `ByPath p -> begin match EcEnv.Op.by_path_opt p env with | Some ({ op_kind = OB_oper _ } as refop) -> - let tyargs = List.map tvar refop.op_tparams in + let tyargs = List.map (fun (x, _) -> tvar x) refop.op_tparams in let body = if refop.op_clinline then (match refop.op_kind with @@ -594,17 +536,14 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = end | `Direct body -> - assert (List.is_empty refop.op_tparams); - let newop = - mk_op - ~opaque:optransparent ~clinline:(opmode <> `Alias) - [] body.f_ty (Some (OP_Plain body)) refop.op_loca in - (newop, body) - + let newop = + mk_op ~opaque:optransparent ~clinline:(opmode <> `Alias) + refop.op_tparams body.f_ty (Some (OP_Plain body)) refop.op_loca in + (newop, body) in match opmode with | `Alias -> - let subst, x = rename ove subst (rk, x) in + let subst, x = rename ove subst (`Op, x) in (newop, subst, x, true) | `Inline _ -> @@ -614,7 +553,7 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = with EcFol.CannotTranslate -> clone_error env (CE_InlinedOpIsForm (snd ove.ovre_prefix, x)) in - let subst1 = (newop.op_tparams, body) in + let subst1 = (List.map fst newop.op_tparams, body) in let subst = EcSubst.add_opdef subst (xpath ove x) subst1 in (newop, subst, x, false) in @@ -624,11 +563,8 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = Mp.add opp (newop, alias) ops in begin - try Compatible.for_operator env refop newop - with - | NoException -> - clone_error env (CE_NoExceptions) - | Incompatible err -> + try operator_compatible env refop newop + with Incompatible err -> clone_error env (CE_OpIncompatible ((snd ove.ovre_prefix, x), err)) end; @@ -667,7 +603,7 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = let newpr, body = match prov with | `BySyntax prov -> - let tp = prov.prov_tyvars in + let tp = prov.prov_tyvars |> omap (List.map (fun tv -> (tv, []))) in let ue = EcTyping.transtyvars env (loc, tp) in let body = let env, xs = EcTyping.trans_binding env ue prov.prov_args in @@ -679,7 +615,7 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = begin try - Compatible.for_ty env ue + ty_compatible env ue (reftyvars, refty) (EcUnify.UniEnv.tparams ue, body.f_ty) with Incompatible err -> @@ -687,9 +623,12 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = (CE_OpIncompatible ((snd ove.ovre_prefix, x), err)) end; - if not (EcUnify.UniEnv.closed ue) then - ove.ovre_hooks.herr - ~loc "this predicate body contains free type variables"; + Option.iter (fun infos -> + ove.ovre_hooks.herr ~loc + (Format.asprintf + "this predicate body contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos) + ) (EcUnify.UniEnv.xclosed ue); let fs = CS.Tuni.subst (EcUnify.UniEnv.close ue) in let body = EcFol.Fsubst.f_subst fs body in @@ -707,7 +646,7 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = | `ByPath p -> begin match EcEnv.Op.by_path_opt p env with | Some ({ op_kind = OB_pred _ } as refop) -> - let tyargs = List.map tvar refop.op_tparams in + let tyargs = List.map (fun (x, _) -> tvar x) refop.op_tparams in let body = if refop.op_clinline then (match refop.op_kind with @@ -724,16 +663,12 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = end | `Direct body -> - assert (List.is_empty refpr.op_tparams); - let newpr = - { op_tparams = []; - op_ty = body.f_ty; - op_kind = OB_pred (Some (PR_Plain body)); - op_opaque = oopr.op_opaque; - op_clinline = prmode <> `Alias; - op_loca = refpr.op_loca; - op_unfold = refpr.op_unfold; } in - (newpr, body) + let newpr = + { refpr with + op_kind = OB_pred (Some (PR_Plain body)); + op_ty = body.f_ty; + op_clinline = (prmode <> `Alias); } + in (newpr, body) in match prmode with @@ -742,14 +677,14 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = (newpr, subst, x) | `Inline _ -> - let subst1 = (newpr.op_tparams, body) in + let subst1 = (List.map fst newpr.op_tparams, body) in let subst = EcSubst.add_pddef subst (xpath ove x) subst1 in (newpr, subst, x) in begin - try Compatible.for_operator env refpr newpr + try operator_compatible env refpr newpr with Incompatible err -> clone_error env (CE_OpIncompatible ((snd ove.ovre_prefix, x), err)) end; @@ -811,13 +746,9 @@ and replay_axd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, ax) = match Msym.find_opt x (ove.ovre_ovrd.evc_lemmas.ev_bynames) with | Some (pt, hide, explicit) -> Some (pt, hide, explicit) | None when is_axiom ax.ax_kind -> - List.Exceptionless.find_map (function - | (pt, None) -> - if check_evtags (Ssym.elements tags) then - Some (pt, `Alias, false) - else None - | (pt, Some pttags) -> - if check_evtags ~tags:pttags (Ssym.elements tags) then + List.Exceptionless.find_map + (fun (pt, pttags) -> + if check_evtags ?tags:pttags (Ssym.elements tags) then Some (pt, `Alias, false) else None) ove.ovre_glproof @@ -828,7 +759,10 @@ and replay_axd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, ax) = | Some (pt, hide, explicit) -> if explicit && not (EcDecl.is_axiom ax.ax_kind) then clone_error (EcSection.env scenv) (CE_ProofForLemma (snd ove.ovre_prefix, x)); - let ax = { ax with ax_kind = `Lemma } in + let ax = { ax with + ax_kind = `Lemma; + ax_smt = if hide <> `Alias then false else ax.ax_smt + } in let axc = { axc_axiom = (x, ax); axc_path = EcPath.fromqsymbol (snd ove.ovre_prefix, x); axc_tac = pt; @@ -877,11 +811,54 @@ and replay_modtype and replay_mod (ove : _ ovrenv) (subst, ops, proofs, scope) (import, (me : top_module_expr)) = - let subst, name = rename ove subst (`Module, me.tme_expr.me_name) in - let me = EcSubst.subst_top_module subst me in - let me = { me with tme_expr = { me.tme_expr with me_name = name } } in - let item = (Th_module me) in - (subst, ops, proofs, ove.ovre_hooks.hadd_item scope ~import item) + match Msym.find_opt me.tme_expr.me_name ove.ovre_ovrd.evc_modexprs with + | None -> + let subst, name = rename ove subst (`Module, me.tme_expr.me_name) in + let me = EcSubst.subst_top_module subst me in + let me = { me with tme_expr = { me.tme_expr with me_name = name } } in + let item = (Th_module me) in + (subst, ops, proofs, ove.ovre_hooks.hadd_item scope ~import item) + + | Some { pl_desc = (newname, mode) } -> + let name = me.tme_expr.me_name in + let env = EcSection.env (ove.ovre_hooks.henv scope) in + + let mp, (newme, newlc) = EcEnv.Mod.lookup (unloc newname) env in + + let substme = EcSubst.add_moddef subst ~src:(xpath ove name) ~dst:mp in + + let me = EcSubst.subst_top_module substme me in + let me = { me with tme_expr = { me.tme_expr with me_name = name } } in + let newme = { newme with me_name = name } in + let newme = { tme_expr = newme; tme_loca = Option.get newlc; } in + + if not (EcReduction.EqTest.for_mexpr ~body:false env me.tme_expr newme.tme_expr) then + clone_error env (CE_ModIncompatible (snd ove.ovre_prefix, name)); + + let subst = + match mode with + | `Alias -> + fst (rename ove subst (`Module, name)) + | `Inline _ -> + substme in + + let newme = + if mode = `Alias || mode = `Inline `Keep then + let alias = ME_Alias ( + List.length newme.tme_expr.me_params, + EcPath.m_apply + mp + (List.map (fun (id, _) -> EcPath.mident id) newme.tme_expr.me_params) + ) + in { newme with tme_expr = { newme.tme_expr with me_body = alias } } + else newme in + + let scope = + if keep_of_mode mode + then ove.ovre_hooks.hadd_item scope ~import (Th_module newme) + else scope in + + (subst, ops, proofs, scope) (* -------------------------------------------------------------------- *) and replay_export @@ -915,12 +892,12 @@ and replay_addrw (* -------------------------------------------------------------------- *) and replay_auto - (ove : _ ovrenv) (subst, ops, proofs, scope) (import, at_base) + (ove : _ ovrenv) (subst, ops, proofs, scope) (import, lvl, base, ps, lc) = let env = EcSection.env (ove.ovre_hooks.henv scope) in - let axioms = List.map (fst_map (EcSubst.subst_path subst)) at_base.axioms in - let axioms = List.filter (fun (p, _) -> Option.is_some (EcEnv.Ax.by_path_opt p env)) axioms in - let scope = ove.ovre_hooks.hadd_item scope ~import (Th_auto { at_base with axioms }) in + let ps = List.map (fun (p, k) -> (EcSubst.subst_path subst p, k)) ps in + let ps = List.filter (fun (p, _) -> Option.is_some (EcEnv.Ax.by_path_opt p env)) ps in + let scope = ove.ovre_hooks.hadd_item scope ~import (Th_auto { level = lvl; base; axioms = ps; locality = lc }) in (subst, ops, proofs, scope) (* -------------------------------------------------------------------- *) @@ -950,9 +927,17 @@ and replay_reduction (subst, ops, proofs, scope) +(* -------------------------------------------------------------------- *) +and replay_typeclass + (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, tc) += + let tc = EcSubst.subst_tc subst tc in + let scope = ove.ovre_hooks.hadd_item scope ~import (Th_typeclass (x, tc)) in + (subst, ops, proofs, scope) + (* -------------------------------------------------------------------- *) and replay_instance - (ove : _ ovrenv) (subst, ops, proofs, scope) (import, (typ, ty), tc, lc) + (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, tci) = let opath = ove.ovre_opath in let npath = ove.ovre_npath in @@ -980,8 +965,8 @@ and replay_instance | OB_oper (Some (OP_Record _)) | OB_oper (Some (OP_Proj _)) | OB_oper (Some (OP_Fix _)) - | OB_oper (Some (OP_Exn _)) - | OB_oper (Some (OP_TC )) -> + | OB_oper (Some (OP_TC _)) + | OB_oper (Some (OP_Exn _)) -> Some (EcPath.pappend npath q) | OB_oper (Some (OP_Plain f)) -> match f.f_node with @@ -991,9 +976,15 @@ and replay_instance let forpath p = odfl p (forpath p) in + let fortypeclass (tc : typeclass) = + { tc_name = forpath tc.tc_name; + tc_args = List.map (EcSubst.subst_etyarg subst) tc.tc_args; } in + try - let (typ, ty) = EcSubst.subst_genty subst (typ, ty) in - let tc = + let subst, tci_params = EcSubst.fresh_tparams subst tci.tci_params in + let tci_type = EcSubst.subst_ty subst tci.tci_type in + + let tci_instance : tcibody = let rec doring cr = { r_type = EcSubst.subst_ty subst cr.r_type; r_zero = forpath cr.r_zero; @@ -1016,93 +1007,87 @@ and replay_instance f_inv = forpath cr.f_inv; f_div = cr.f_div |> omap forpath; } in - match tc with - | `Ring cr -> `Ring (doring cr) - | `Field cr -> `Field (dofield cr) - | `General p -> `General (forpath p) + match tci.tci_instance with + | `Ring cr -> `Ring (doring cr) + | `Field cr -> `Field (dofield cr) + + | `General (tc, syms) -> + let tc = fortypeclass tc in + let syms = + Option.map + (Mstr.map (fun (p, tys) -> + (forpath p, List.map (EcSubst.subst_etyarg subst) tys))) + syms in + `General (tc, syms) in - let scope = ove.ovre_hooks.hadd_item scope ~import (Th_instance ((typ, ty), tc, lc)) in - (subst, ops, proofs, scope) + let tci = { tci with tci_params; tci_type; tci_instance; } in - with E.InvInstPath -> - (subst, ops, proofs, scope) - -(* -------------------------------------------------------------------- *) -and replay_alias - (ove : _ ovrenv) (subst, ops, proofs, scope) (import, name, target) -= - let scenv = ove.ovre_hooks.henv scope in - let env = EcSection.env scenv in - let p = EcSubst.subst_path subst target in + let scope = + ove.ovre_hooks.hadd_item scope ~import (Th_instance (x, tci)) + in (subst, ops, proofs, scope) - if is_none (EcEnv.Theory.by_path_opt p env) then - (subst, ops, proofs, scope) - else - let scope = ove.ovre_hooks.hadd_item scope ~import (Th_alias (name, target)) in + with E.InvInstPath -> (subst, ops, proofs, scope) (* -------------------------------------------------------------------- *) -and replay1 (ove : _ ovrenv) (subst, ops, proofs, scope) (hidden, item) = - let import = not hidden && item.ti_import in - +and replay1 (ove : _ ovrenv) (subst, ops, proofs, scope) item = match item.ti_item with | Th_type (x, otyd) -> - replay_tyd ove (subst, ops, proofs, scope) (import, x, otyd) + replay_tyd ove (subst, ops, proofs, scope) (item.ti_import, x, otyd) | Th_operator (x, ({ op_kind = OB_oper _ } as oopd)) -> - replay_opd ove (subst, ops, proofs, scope) (import, x, oopd) + replay_opd ove (subst, ops, proofs, scope) (item.ti_import, x, oopd) | Th_operator (x, ({ op_kind = OB_pred _} as oopr)) -> - replay_prd ove (subst, ops, proofs, scope) (import, x, oopr) + replay_prd ove (subst, ops, proofs, scope) (item.ti_import, x, oopr) | Th_operator (x, ({ op_kind = OB_nott _} as oont)) -> - replay_ntd ove (subst, ops, proofs, scope) (import, x, oont) + replay_ntd ove (subst, ops, proofs, scope) (item.ti_import, x, oont) | Th_axiom (x, ax) -> - replay_axd ove (subst, ops, proofs, scope) (import, x, ax) + replay_axd ove (subst, ops, proofs, scope) (item.ti_import, x, ax) | Th_modtype (x, modty) -> - replay_modtype ove (subst, ops, proofs, scope) (import, x, modty) + replay_modtype ove (subst, ops, proofs, scope) (item.ti_import, x, modty) | Th_module me -> - replay_mod ove (subst, ops, proofs, scope) (import, me) + replay_mod ove (subst, ops, proofs, scope) (item.ti_import, me) | Th_export (p, lc) -> - replay_export ove (subst, ops, proofs, scope) (import, p, lc) + replay_export ove (subst, ops, proofs, scope) (item.ti_import, p, lc) - | Th_baserw (x, lc) when not hidden -> - replay_baserw ove (subst, ops, proofs, scope) (import, x, lc) + | Th_baserw (x, lc) -> + replay_baserw ove (subst, ops, proofs, scope) (item.ti_import, x, lc) - | Th_addrw (p, l, lc) when not hidden -> - replay_addrw ove (subst, ops, proofs, scope) (import, p, l, lc) + | Th_addrw (p, l, lc) -> + replay_addrw ove (subst, ops, proofs, scope) (item.ti_import, p, l, lc) - | Th_reduction rules when not hidden -> - replay_reduction ove (subst, ops, proofs, scope) (import, rules) + | Th_reduction rules -> + replay_reduction ove (subst, ops, proofs, scope) (item.ti_import, rules) - | Th_auto at_base when not hidden -> - replay_auto ove (subst, ops, proofs, scope) (import, at_base) + | Th_auto { level = lvl; base; axioms = ps; locality = lc } -> + replay_auto ove (subst, ops, proofs, scope) (item.ti_import, lvl, base, ps, lc) - | Th_instance ((typ, ty), tc, lc) when not hidden -> - replay_instance ove (subst, ops, proofs, scope) (import, (typ, ty), tc, lc) + | Th_typeclass (x, tc) -> + replay_typeclass ove (subst, ops, proofs, scope) (item.ti_import, x, tc) - | Th_baserw _ - | Th_addrw _ - | Th_reduction _ - | Th_auto _ - | Th_instance _ -> - (subst, ops, proofs, scope) + | Th_instance (x, tci) -> + replay_instance ove (subst, ops, proofs, scope) (item.ti_import, x, tci) - | Th_alias (name, target) -> - replay_alias ove (subst, ops, proofs, scope) (item.ti_import, name, target) + | Th_alias (n, p) -> + let p = EcSubst.subst_path subst p in + let scope = + ove.ovre_hooks.hadd_item scope ~import:item.ti_import (Th_alias (n, p)) in + (subst, ops, proofs, scope) | Th_theory (ox, cth) -> begin let thmode = cth.cth_mode in let (subst, x) = rename ove subst (`Theory, ox) in let subovrds = Msym.find_opt ox ove.ovre_ovrd.evc_ths in - let subovrds = EcUtils.odfl (evc_empty, false) subovrds in - let subovrds, clear = subovrds in - let hidden = hidden || clear in + let subovrds, sub_clear = + EcUtils.odfl (evc_empty, false) subovrds in + let import = item.ti_import && not sub_clear in let subove = { ove with ovre_ovrd = subovrds; ovre_abstract = ove.ovre_abstract || (thmode = `Abstract); @@ -1117,10 +1102,9 @@ and replay1 (ove : _ ovrenv) (subst, ops, proofs, scope) (hidden, item) = let new_local = odfl cth.cth_loca ove.ovre_local in let subscope = ove.ovre_hooks.hthenter scope thmode x new_local in let (subst, ops, proofs, subscope) = - List.fold_left - (fun state item -> replay1 subove state (hidden, item)) + List.fold_left (replay1 subove) (subst, ops, proofs, subscope) cth.cth_items in - let scope = ove.ovre_hooks.hthexit ~import:(not hidden) subscope `Full in + let scope = ove.ovre_hooks.hthexit subscope ~import `Full in (subst, ops, proofs, scope) in (subst, ops, proofs, subscope) @@ -1129,7 +1113,7 @@ and replay1 (ove : _ ovrenv) (subst, ops, proofs, scope) (hidden, item) = (* -------------------------------------------------------------------- *) let replay (hooks : 'a ovrhooks) ~abstract ~override_locality ~incl ~clears ~renames - ~opath ~npath ovrds (scope : 'a) (name, hidden, items, base_local) + ~opath ~npath ovrds (scope : 'a) (name, items, base_local) = let subst = EcSubst.add_path EcSubst.empty ~src:opath ~dst:npath in let ove = { @@ -1145,13 +1129,12 @@ let replay (hooks : 'a ovrhooks) ovre_glproof = ovrds.evc_lemmas.ev_global; } in - let mode = if abstract then `Abstract else `Concrete in - let new_local = odfl base_local override_locality in - let scope = if incl then scope else hooks.hthenter scope mode name new_local in try + let mode = if abstract then `Abstract else `Concrete in + let new_local = odfl base_local override_locality in + let scope = if incl then scope else hooks.hthenter scope mode name new_local in let _, _, proofs, scope = - List.fold_left - (fun state item -> replay1 ove state (hidden, item)) + List.fold_left (replay1 ove) (subst, Mp.empty, [], scope) items in let scope = if incl then scope else hooks.hthexit scope ~import:true `No in (List.rev proofs, scope) diff --git a/src/ecTheoryReplay.mli b/src/ecTheoryReplay.mli index a5de45eace..dd18bf5877 100644 --- a/src/ecTheoryReplay.mli +++ b/src/ecTheoryReplay.mli @@ -33,13 +33,8 @@ and 'a ovrhooks = { (* -------------------------------------------------------------------- *) val replay : 'a ovrhooks - -> abstract:bool - -> override_locality:EcTypes.is_local option - -> incl:bool - -> clears:Sp.t - -> renames:(renaming list) - -> opath:path - -> npath:path - -> evclone - -> 'a -> symbol * bool * theory_item list * EcTypes.is_local + -> abstract:bool -> override_locality:EcTypes.is_local option -> incl:bool + -> clears:Sp.t -> renames:(renaming list) + -> opath:path -> npath:path -> evclone + -> 'a -> symbol * theory_item list * EcTypes.is_local -> axclone list * 'a diff --git a/src/ecTypeClass.ml b/src/ecTypeClass.ml index f142cc94d9..588dd380be 100644 --- a/src/ecTypeClass.ml +++ b/src/ecTypeClass.ml @@ -1,87 +1,413 @@ (* -------------------------------------------------------------------- *) -open EcUtils +open EcIdent open EcPath +open EcUtils +open EcAst +open EcTheory + +(* -------------------------------------------------------------------- *) +exception NoMatch (* -------------------------------------------------------------------- *) -type graph = { - tcg_nodes : Sp.t Mp.t; - tcg_closure : Sp.t Mp.t; -} +module TyMatch(E : sig val env : EcEnv.env end) = struct + let rec doit_type (map : ty option Mid.t) (pattern : ty) (ty : ty) = + let pattern = EcEnv.ty_hnorm pattern E.env in + let ty = EcEnv.ty_hnorm ty E.env in + + match pattern.ty_node, ty.ty_node with + | Tunivar _, _ -> + assert false + + (* Tunivar on the [ty] side is a wildcard: the goal type contains + a fresh univar that the unifier will resolve later. Don't fail + the match — leave the pattern's [Tvar] entries (if any) unbound + and let the caller decide whether the partial match is enough. *) + | _, Tunivar _ -> + map + + | Tvar a, _ -> begin + (* [a] may not be in [map] when the pattern carries free Tvars + (e.g. an instance whose carrier was a section-local tparam + that did not get generalised to [tci_params]). Treat that as + a non-match rather than crashing the inference loop. *) + match Mid.find_opt a map with + | None -> raise NoMatch + | Some None -> + Mid.add a (Some ty) map + + | Some (Some ty') -> + if not (EcCoreEqTest.for_type E.env ty ty') then + raise NoMatch; + map + + end + + | Tglob id1, Tglob id2 when EcIdent.id_equal id1 id2 -> + map + + | Tconstr (p, args), Tconstr (p', args') -> + if not (EcPath.p_equal p p') then + raise NoMatch; + doit_etyargs map args args' + + | Ttuple ptns, Ttuple tys when List.length ptns = List.length tys -> + doit_types map ptns tys -type nodes = { - tcn_graph : graph; - tcn_nodes : Sp.t; -} + | Tfun (p1, p2), Tfun (ty1, ty2) -> + doit_types map [p1; p2] [ty1; ty2] -type node = EcPath.path + | _, _ -> + raise NoMatch -exception CycleDetected + and doit_types (map : ty option Mid.t) (pts : ty list) (tys : ty list) = + List.fold_left2 doit_type map pts tys + + and doit_etyarg (map : ty option Mid.t) ((pattern, ptcws) : etyarg) ((ty, ttcws) : etyarg) = + let map = doit_type map pattern ty in + let map = doit_tcws map ptcws ttcws in + map + + and doit_etyargs (map : ty option Mid.t) (pts : etyarg list) (etys : etyarg list) = + List.fold_left2 doit_etyarg map pts etys + + and doit_tcw (map : ty option Mid.t) (ptcw : tcwitness) (ttcw : tcwitness) = + match ptcw, ttcw with + | TCIUni _, _ -> + assert false + + | TCIConcrete ptcw, TCIConcrete ttcw -> + if not (EcPath.p_equal ptcw.path ttcw.path) then + raise NoMatch; + doit_etyargs map ptcw.etyargs ttcw.etyargs + + | TCIAbstract _, TCIAbstract _ -> + if not (EcAst.tcw_equal ptcw ttcw) then + raise NoMatch; + map + + | _, _ -> + raise NoMatch + + and doit_tcws (map : ty option Mid.t) (ptcws : tcwitness list) (ttcws : tcwitness list) = + List.fold_left2 doit_tcw map ptcws ttcws +end + +(* -------------------------------------------------------------------- *) +let ty_match (env : EcEnv.env) (params : ident list) ~(pattern : ty) ~(ty : ty) = + let module M = TyMatch(struct let env = env end) in + let map = Mid.of_list (List.map (fun a -> (a, None)) params) in + M.doit_type map pattern ty + +(* -------------------------------------------------------------------- *) +let etyargs_match + (env : EcEnv.env) + (params : ident list) + ~(patterns : etyarg list) + ~(etyargs : etyarg list) += + let module M = TyMatch(struct let env = env end) in + let map = Mid.of_list (List.map (fun a -> (a, None)) params) in + M.doit_etyargs map patterns etyargs (* -------------------------------------------------------------------- *) -module Graph = struct - let empty : graph = { - tcg_nodes = Mp.empty; - tcg_closure = Mp.empty; - } +let rec check_tcinstance + (env : EcEnv.env) + (ty : ty) + (tc : typeclass) + ((p, tci) : path option * tcinstance) += + let exception Bailout in - let dump gr = - Printf.sprintf "%s\n" - (String.concat "\n" - (List.map - (fun (p, ps) -> Printf.sprintf "%s -> %s" - (EcPath.tostring p) - (String.concat ", " (List.map EcPath.tostring (Sp.elements ps)))) - (Mp.bindings gr.tcg_nodes))) + try + let p = oget ~exn:Bailout p in - let has_path ~src ~dst g = - if EcPath.p_equal src dst then - true + let tgargs = + match tci.tci_instance with + | `General (tgp, _) -> + if not (EcPath.p_equal tc.tc_name tgp.tc_name) then + raise Bailout; + tgp.tc_args + | _ -> raise Bailout in + + let map = + etyargs_match env (List.fst tci.tci_params) + ~patterns:tgargs ~etyargs:tc.tc_args in + + let map = + let module M = TyMatch(struct let env = env end) in + M.doit_type map tci.tci_type ty in + + let _, args = List.fold_left_map (fun subst (a, aargs) -> + let aty = oget ~exn:Bailout (Mid.find a map) in + let aargs = List.map (fun aarg -> + let aarg = EcCoreSubst.Tvar.subst_tc subst aarg in + oget ~exn:Bailout (infer env aty aarg) + ) aargs in + let subst = Mid.add a (aty, aargs) subst in + (subst, (aty, aargs)) + ) Mid.empty tci.tci_params in + + Some (TCIConcrete { path = p; etyargs = args; lift = []; }) + + with Bailout | NoMatch -> None + +(* -------------------------------------------------------------------- *) +(* Walk the parent DAG of [tc'] looking for [tc]. Returns the first + path (list of parent-edge indices) reaching [tc], or [None]. With + single-parent inheritance this is the all-zeros path; with + multi-parent classes the path encodes which parent is taken at + each step. Mirrors [match_tc_offset]'s walk in [EcUnify]. *) +and lift_to_tc (env : EcEnv.env) (tc' : typeclass) (tc : typeclass) : int list option = + let eq_tc t = + EcPath.p_equal tc.tc_name t.tc_name + && List.length tc.tc_args = List.length t.tc_args + && List.for_all2 + (fun (a, _) (b, _) -> EcCoreEqTest.for_type env a b) + tc.tc_args t.tc_args in + let rec walk t path = + if eq_tc t then Some (List.rev path) else - match Mp.find_opt src g.tcg_closure with - | None -> false - | Some m -> Mp.mem dst m - - let add ~src ~dst g = - if has_path ~src ~dst g then - raise CycleDetected; - - match Mp.find_opt src g.tcg_nodes with - | Some m when Mp.mem dst m -> g - | _ -> - let up_node m = Sp.add dst (odfl Sp.empty m) - and up_clos m = - Sp.union - (odfl Sp.empty (Mp.find_opt dst g.tcg_closure)) - (Sp.add dst (odfl Sp.empty m)) - in - { g with - tcg_nodes = Mp.change (some -| up_node) src g.tcg_nodes; - tcg_closure = Mp.change (some -| up_clos) src g.tcg_closure; } -end + let decl = EcEnv.TypeClass.by_path t.tc_name env in + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> Mid.add a etyarg s) + Mid.empty decl.tc_tparams t.tc_args in + let rec try_parents i = function + | [] -> None + | (parent, _ren) :: rest -> + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + (match walk parent (i :: path) with + | Some _ as r -> r + | None -> try_parents (i + 1) rest) + in try_parents 0 decl.tc_prts + in walk tc' [] (* -------------------------------------------------------------------- *) -module Nodes = struct - let empty g = { - tcn_graph = g; - tcn_nodes = Sp.empty; - } +(* Mode-#6 fallback: when [ty] is [Tconstr p _] and [p]'s declaration + is [`Abstract tcs] (e.g. a section-declared abstract type with + class bounds, like [declare type c <: comring]), build the + [TCIAbstract { support = `Abs p; offset; lift }] witness by finding + an entry in [tcs] that reaches [tc] via its parent DAG. Without + this, [infer]'s recursion on a parametric-instance's tparam + constraint fails for section-abstract carriers (Path B in the + resolver), even though [EcUnify] handles the same case via its + [strat_abs_via_decl]. *) +and infer_via_abs_decl (env : EcEnv.env) (ty : ty) (tc : typeclass) : tcwitness option = + match ty.ty_node with + | Tconstr (p, _) -> begin + match EcEnv.Ty.by_path_opt p env with + | Some { tyd_type = `Abstract tcs; _ } -> + let rec find_offset i = function + | [] -> None + | tc' :: rest -> + (match lift_to_tc env tc' tc with + | Some lift -> + Some (TCIAbstract { support = `Abs p; offset = i; lift }) + | None -> find_offset (i + 1) rest) + in find_offset 0 tcs + | _ -> None + end + | _ -> None - let add n nodes = - let module E = struct exception Discard end in +(* -------------------------------------------------------------------- *) +and infer (env : EcEnv.env) (ty : ty) (tc : typeclass) = + match + List.find_map_opt + (check_tcinstance env ty tc) + (EcEnv.TcInstance.get_all env) + with + | Some _ as w -> w + | None -> infer_via_abs_decl env ty tc - try - let aout = - Sp.filter - (fun p -> - if Graph.has_path ~src:p ~dst:n nodes.tcn_graph then raise E.Discard; - not (Graph.has_path ~src:n ~dst:p nodes.tcn_graph)) - nodes.tcn_nodes - in - { nodes with tcn_nodes = Sp.add n aout } - with E.Discard -> nodes +(* -------------------------------------------------------------------- *) +(* Like [infer] but returns ALL matching instances as witnesses. Used + to detect ambiguity (multi-flavor inheritance, e.g. a comring with + both addmonoid- and mulmonoid-derived monoid views on the same + carrier) — the caller may then defer commitment until other + unification steps narrow the choice. *) +and infer_all (env : EcEnv.env) (ty : ty) (tc : typeclass) = + List.filter_map + (check_tcinstance env ty tc) + (EcEnv.TcInstance.get_all env) - let toset nodes = nodes.tcn_nodes +(* -------------------------------------------------------------------- *) +(* Build one [tcwitness] per entry of [tcs] for a carrier [body], + suitable for plugging into the [tcwitness list] slot of an + [add_tydef] binding. The expected witness for [body : tc] is + queried via [infer]; if no instance is registered, falls back to + a [`Abs body_path] / [`Var a] placeholder so the substitution + matches the pre-fix shape. With this fallback the helper is + non-failing — callers that want to error on a missing instance + should check [infer] separately. *) +let witnesses_for_body + (env : EcEnv.env) (body : ty) (tcs : typeclass list) + : tcwitness list += + List.map (fun tc -> + match infer env body tc with + | Some w -> w + | None -> + let support = + match body.ty_node with + | Tvar a -> `Var a + | Tconstr (p, _) -> `Abs p + | _ -> + (* Last-ditch dummy; should never arise for sensible + clone bodies, which are always [Tvar] or [Tconstr]. *) + `Abs (EcPath.psymbol "?") in + TCIAbstract { support; offset = 0; lift = [] } + ) tcs - let reduce set g = - toset (Sp.fold add set (empty g)) -end +(* -------------------------------------------------------------------- *) +(* Match a candidate instance against [tc] on its arguments only, + leaving the carrier ([tci.tci_type]) for the caller to unify with + the goal carrier. Returns the partial type-substitution that + pinned the [tci_params] from the match. *) +let candidates_by_args (env : EcEnv.env) (tc : typeclass) + : (EcPath.path option * tcinstance * ty option EcIdent.Mid.t) list += + let try_one (p, tci) = + match tci.tci_instance with + | `General (tgp, _) when EcPath.p_equal tc.tc_name tgp.tc_name -> begin + try + let map = + etyargs_match env (List.fst tci.tci_params) + ~patterns:tgp.tc_args ~etyargs:tc.tc_args + in Some (p, tci, map) + with NoMatch -> None + end + | _ -> None + in List.filter_map try_one (EcEnv.TcInstance.get_all env) + +(* -------------------------------------------------------------------- *) +(* Flatten the parent DAG of a typeclass into a deduplicated list, + self first. With single-inheritance this is the linear chain + [tc; parent; grandparent; ...]; with multi-inheritance it's a + BFS walk: [tc; parent_1; ...; parent_n; ...grandparents...]. + Each ancestor's [tc_args] is substituted along the path so the + args reference [tc]'s tparams. Duplicates are dropped (an ancestor + reachable via multiple paths appears once, at the shortest path). *) +let ancestors (env : EcEnv.env) (tc : typeclass) : typeclass list = + let parents (tc : typeclass) : typeclass list = + let decl = EcEnv.TypeClass.by_path tc.tc_name env in + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> Mid.add a etyarg s) + Mid.empty decl.tc_tparams tc.tc_args in + List.map (fun (p, _ren) -> EcCoreSubst.Tvar.subst_tc subst p) decl.tc_prts in + let same (a : typeclass) (b : typeclass) = + EcPath.p_equal a.tc_name b.tc_name in + let rec bfs (frontier : typeclass list) (acc : typeclass list) = + match frontier with + | [] -> List.rev acc + | tc :: rest -> + if List.exists (same tc) acc then bfs rest acc + else bfs (rest @ parents tc) (tc :: acc) + in bfs [tc] [] + +(* -------------------------------------------------------------------- *) +(* Compose two renamings. + [outer] is declared on a parent edge: maps grandparent op names + to parent op names (only listed entries are renamed; unlisted + passes through identity). + [inner] is the accumulated renaming on the child side: maps + parent op names to child op names. + Result: grandparent op names → child op names. + + Two cases: + - For each (gp_name, p_name) in outer: child's name for that op + is [inner(p_name)], defaulting to [p_name] if unlisted. + - For each (p_name, c_name) in inner whose [p_name] is NOT + referenced in outer (neither as a value nor as a key): the op + passes through outer as identity, so grandparent's name for it + is [p_name] and child's name is [c_name]. Add [(p_name, c_name)]. *) +let compose_renaming + ~(outer : (EcSymbols.symbol * EcSymbols.symbol) list) + ~(inner : (EcSymbols.symbol * EcSymbols.symbol) list) + : (EcSymbols.symbol * EcSymbols.symbol) list += + let inner_map = EcMaps.Mstr.of_list inner in + let from_outer = + List.map + (fun (gp_name, p_name) -> + let c_name = odfl p_name (EcMaps.Mstr.find_opt p_name inner_map) in + (gp_name, c_name)) + outer in + let outer_p_names = + List.fold_left (fun s (_, p) -> EcMaps.Sstr.add p s) EcMaps.Sstr.empty outer in + let outer_gp_names = + List.fold_left (fun s (gp, _) -> EcMaps.Sstr.add gp s) EcMaps.Sstr.empty outer in + let from_inner = + List.filter_map + (fun (p_name, c_name) -> + if EcMaps.Sstr.mem p_name outer_p_names || EcMaps.Sstr.mem p_name outer_gp_names + then None + else Some (p_name, c_name)) + inner in + from_outer @ from_inner + +(* -------------------------------------------------------------------- *) +(* True iff op [n] survives the cumulative ancestor→child renaming + [ren] under the same name. An op is preserved when [ren] doesn't + mention it (passes through as identity), or when it explicitly + maps to itself. *) +let op_preserved + (ren : (EcSymbols.symbol * EcSymbols.symbol) list) + (n : EcSymbols.symbol) + : bool += + match List.assoc_opt n ren with + | None -> true + | Some n' -> n = n' + +(* -------------------------------------------------------------------- *) +(* Variant of [ancestors] that also returns the cumulative op renaming + accumulated along the BFS walk from [tc] to each ancestor. The + renaming maps the ancestor's op names to the corresponding op + names declared in (or inherited by) [tc]. *) +let ancestors_with_renaming + (env : EcEnv.env) (tc : typeclass) + : (typeclass * (EcSymbols.symbol * EcSymbols.symbol) list) list += + let parents tc = + let decl = EcEnv.TypeClass.by_path tc.tc_name env in + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> Mid.add a etyarg s) + Mid.empty decl.tc_tparams tc.tc_args in + List.map + (fun (p, ren) -> (EcCoreSubst.Tvar.subst_tc subst p, ren)) + decl.tc_prts in + (* Compose two renamings. + [outer] is declared on a parent edge: maps grandparent op names + to parent op names (only listed entries are renamed; unlisted + passes through identity). + [inner] is the accumulated renaming on the child side: maps + parent op names to child op names. + Result: grandparent op names → child op names. + + Two cases: + - For each (gp_name, p_name) in outer: child's name for that op + is [inner(p_name)], defaulting to [p_name] if unlisted. + - For each (p_name, c_name) in inner whose [p_name] is NOT + referenced in outer (neither as a value nor as a key): the op + passes through outer as identity, so grandparent's name for it + is [p_name] and child's name is [c_name]. Add [(p_name, c_name)]. *) + let compose = compose_renaming in + let ren_eq r1 r2 = + List.length r1 = List.length r2 + && List.for_all2 (fun (a, b) (c, d) -> a = c && b = d) r1 r2 in + let same (a, ra) (b, rb) = + EcPath.p_equal a.tc_name b.tc_name && ren_eq ra rb in + let rec bfs frontier acc = + match frontier with + | [] -> List.rev acc + | (tc, ren) :: rest -> + if List.exists (same (tc, ren)) acc then bfs rest acc + else + let next = + List.map + (fun (p, p_ren) -> (p, compose ~outer:p_ren ~inner:ren)) + (parents tc) in + bfs (rest @ next) ((tc, ren) :: acc) + in bfs [(tc, [])] [] diff --git a/src/ecTypeClass.mli b/src/ecTypeClass.mli index 9c8b566600..56fc482b56 100644 --- a/src/ecTypeClass.mli +++ b/src/ecTypeClass.mli @@ -1,23 +1,78 @@ (* -------------------------------------------------------------------- *) -open EcPath +open EcAst +open EcTheory +open EcEnv -type node = path +(* -------------------------------------------------------------------- *) +exception NoMatch + +(* -------------------------------------------------------------------- *) +val infer : env -> ty -> typeclass -> tcwitness option + +(* -------------------------------------------------------------------- *) +(* Match [pattern] (with free [Tvar]s listed in [params]) against [ty] + and return the resulting substitution. Raises [NoMatch] on shape + mismatch. *) +val ty_match : + env -> EcIdent.t list + -> pattern:ty -> ty:ty + -> ty option EcIdent.Mid.t -type graph -type nodes +(* -------------------------------------------------------------------- *) +(* Build one [tcwitness] per entry of [tcs] for a carrier [body], + suitable for plugging into the witness slot of an [add_tydef] + binding. Each witness is queried via [infer]; on lookup failure, + falls back to a [`Abs body_path] / [`Var a] placeholder so the + substitution preserves the pre-fix shape (no regression for + TC-free callers). *) +val witnesses_for_body : + env -> ty -> typeclass list -> tcwitness list -exception CycleDetected +(* -------------------------------------------------------------------- *) +(* All matching instances as witnesses (vs. [infer] which returns the + first). Used to detect ambiguity from multi-flavor inheritance. *) +val infer_all : env -> ty -> typeclass -> tcwitness list -module Graph : sig - val empty : graph - val add : src:node -> dst:node -> graph -> graph - val has_path : src:node -> dst:node -> graph -> bool - val dump : graph -> string -end +(* -------------------------------------------------------------------- *) +(* Like [infer], but the carrier may be left abstract: only the + typeclass arguments are matched. Returns the matching instance(s) + with the partial type-substitution that pinned each argument; the + caller must then unify the carrier with [subst tci_type] and recover + the witness. Used by the "infer-by-args" strategy of the unifier + when the carrier is a fresh type univar. *) +val candidates_by_args : + env -> typeclass + -> (EcPath.path option * tcinstance * ty option EcIdent.Mid.t) list -module Nodes : sig - val empty : graph -> nodes - val add : node -> nodes -> nodes - val toset : nodes -> Sp.t - val reduce : Sp.t -> graph -> Sp.t -end +(* -------------------------------------------------------------------- *) +(* Flatten the parent chain: [tc; tc.parent; tc.grandparent; ...]. + Args are substituted along the chain. *) +val ancestors : env -> typeclass -> typeclass list + +(* -------------------------------------------------------------------- *) +(* Like [ancestors], but each ancestor is paired with the cumulative + op renaming accumulated along the BFS walk from [tc]. The renaming + is a list of (ancestor_op_name, local_op_name) pairs. Empty list + means no renaming (plain inheritance). *) +val ancestors_with_renaming : + env -> typeclass -> (typeclass * (EcSymbols.symbol * EcSymbols.symbol) list) list + +(* -------------------------------------------------------------------- *) +(* Compose two cumulative renamings. [outer] is the renaming on a + parent edge (grandparent op → parent op); [inner] is the + already-accumulated renaming on the child side (parent op → child + op). Result maps grandparent op names to child op names. *) +val compose_renaming : + outer:(EcSymbols.symbol * EcSymbols.symbol) list + -> inner:(EcSymbols.symbol * EcSymbols.symbol) list + -> (EcSymbols.symbol * EcSymbols.symbol) list + +(* -------------------------------------------------------------------- *) +(* [op_preserved ren n] is true iff applying the cumulative + ancestor→child renaming [ren] to op name [n] leaves it as [n] (or + doesn't mention [n] at all). Used to filter parent-DAG paths when + resolving a TC witness for a specific named op: only paths whose + cumulative renaming preserves the op name expose that op under + the same name at the carrier site. *) +val op_preserved : + (EcSymbols.symbol * EcSymbols.symbol) list -> EcSymbols.symbol -> bool diff --git a/src/ecTypes.ml b/src/ecTypes.ml index 644673bbd9..58fd55c5a4 100644 --- a/src/ecTypes.ml +++ b/src/ecTypes.ml @@ -42,7 +42,7 @@ let rec dump_ty ty = EcIdent.tostring_internal p | Tunivar i -> - Printf.sprintf "#%d" i + Printf.sprintf "#%d" (i :> int) | Tvar id -> EcIdent.tostring_internal id @@ -52,17 +52,18 @@ let rec dump_ty ty = | Tconstr (p, tys) -> Printf.sprintf "%s[%s]" (EcPath.tostring p) - (String.concat ", " (List.map dump_ty tys)) + (String.concat ", " (List.map dump_ty (List.fst tys))) | Tfun (t1, t2) -> Printf.sprintf "(%s) -> (%s)" (dump_ty t1) (dump_ty t2) (* -------------------------------------------------------------------- *) -let tuni uid = mk_ty (Tunivar uid) -let tvar id = mk_ty (Tvar id) -let tconstr p lt = mk_ty (Tconstr (p, lt)) -let tfun t1 t2 = mk_ty (Tfun (t1, t2)) -let tglob m = mk_ty (Tglob m) +let tuni uid = mk_ty (Tunivar uid) +let tvar id = mk_ty (Tvar id) +let tconstr p lt = mk_ty (Tconstr (p, List.map (fun ty -> (ty, [])) lt)) +let tconstr_tc p lt = mk_ty (Tconstr (p, lt)) +let tfun t1 t2 = mk_ty (Tfun (t1, t2)) +let tglob m = mk_ty (Tglob m) (* -------------------------------------------------------------------- *) let tunit = tconstr EcCoreLib.CI_Unit .p_unit [] @@ -104,7 +105,7 @@ let rec tyfun_flat (ty : ty) = (* -------------------------------------------------------------------- *) let as_tdistr (ty : ty) = match ty.ty_node with - | Tconstr (p, [sty]) + | Tconstr (p, [sty, []]) when EcPath.p_equal p EcCoreLib.CI_Distr.p_distr -> Some sty @@ -113,7 +114,7 @@ let as_tdistr (ty : ty) = let is_tdistr (ty : ty) = as_tdistr ty <> None (* -------------------------------------------------------------------- *) -let ty_map f t = +let rec ty_map (f : ty -> ty) (t : ty) : ty = match t.ty_node with | Tglob _ | Tunivar _ | Tvar _ -> t @@ -121,39 +122,88 @@ let ty_map f t = ttuple (List.Smart.map f lty) | Tconstr (p, lty) -> - let lty = List.Smart.map f lty in - tconstr p lty + let lty = List.Smart.map (etyarg_map f) lty in + tconstr_tc p lty | Tfun (t1, t2) -> tfun (f t1) (f t2) -let ty_fold f s ty = +and etyarg_map (f : ty -> ty) ((ty, tcw) : etyarg) : etyarg = + let ty = f ty in + let tcw = List.Smart.map (tcw_map f) tcw in + (ty, tcw) + +and tcw_map (f : ty -> ty) (tcw : tcwitness) : tcwitness = + match tcw with + | TCIUni _ -> + tcw + + | TCIConcrete ({ etyargs; _ } as c) -> + let etyargs = List.Smart.map (etyarg_map f) etyargs in + TCIConcrete { c with etyargs } + + | TCIAbstract _ -> + tcw + +(* -------------------------------------------------------------------- *) +let rec ty_fold (f : 'a -> ty -> 'a) (v : 'a) (ty : ty) : 'a = match ty.ty_node with - | Tglob _ | Tunivar _ | Tvar _ -> s - | Ttuple lty -> List.fold_left f s lty - | Tconstr(_, lty) -> List.fold_left f s lty - | Tfun(t1,t2) -> f (f s t1) t2 + | Tglob _ | Tunivar _ | Tvar _ -> v + | Ttuple lty -> List.fold_left f v lty + | Tconstr (_, lty) -> List.fold_left (etyarg_fold f) v lty + | Tfun (t1, t2) -> f (f v t1) t2 + +and etyarg_fold (f : 'a -> ty -> 'a) (v : 'a) (ety : etyarg) : 'a = + let (ty, tcw) = ety in + List.fold_left (tcw_fold f) (f v ty) tcw + +and tcw_fold (f : 'a -> ty -> 'a) (v : 'a) (tcw : tcwitness) : 'a = + match tcw with + | TCIConcrete { etyargs } -> + List.fold_left (etyarg_fold f) v etyargs + + | TCIUni _ | TCIAbstract _ -> + v -let ty_sub_exists f t = - match t.ty_node with - | Tglob _ | Tunivar _ | Tvar _ -> false - | Ttuple lty -> List.exists f lty - | Tconstr (_, lty) -> List.exists f lty - | Tfun (t1, t2) -> f t1 || f t2 +(* -------------------------------------------------------------------- *) +let ty_iter (f : ty -> unit) (ty : ty) : unit = + ty_fold (fun () -> f) () ty -let ty_iter f t = - match t.ty_node with - | Tglob _ | Tunivar _ | Tvar _ -> () - | Ttuple lty -> List.iter f lty - | Tconstr (_, lty) -> List.iter f lty - | Tfun (t1,t2) -> f t1; f t2 +let etyarg_iter (f : ty -> unit) (ety : etyarg) : unit = + etyarg_fold (fun () -> f) () ety +let tcw_iter (f : ty -> unit) (tcw : tcwitness) : unit = + tcw_fold (fun () -> f) () tcw + +(* -------------------------------------------------------------------- *) +let ty_sub_exists (f : ty -> bool) (ty : ty) = + let exception Exists in + try + ty_iter (fun ty -> if f ty then raise Exists) ty; + false + with Exists -> true + +let etyarg_sub_exists (f : ty -> bool) (ety : etyarg) = + let exception Exists in + try + etyarg_iter (fun ty -> if f ty then raise Exists) ety; + false + with Exists -> true + +let tcw_sub_exists (f : ty -> bool) (tcw : tcwitness) = + let exception Exists in + try + tcw_iter (fun ty -> if f ty then raise Exists) tcw; + false + with Exists -> true + +(* -------------------------------------------------------------------- *) exception FoundUnivar -let rec ty_check_uni t = - match t.ty_node with +let rec ty_check_uni (ty : ty) : unit = + match ty.ty_node with | Tunivar _ -> raise FoundUnivar - | _ -> ty_iter ty_check_uni t + | _ -> ty_iter ty_check_uni ty let rec var_mem ?(check_glob = false) id t = match t.ty_node with @@ -204,7 +254,6 @@ let ovar_of_var { v_name = n; v_type = t } = { ov_name = Some n; ov_type = t } module Tvar = struct - let rec fv_rec fv t = match t.ty_node with | Tvar id -> Sid.add id fv @@ -216,6 +265,34 @@ end let ty_fv_and_tvar (ty : ty) = EcIdent.fv_union ty.ty_fv (Mid.map (fun () -> 1) (Tvar.fv ty)) +(* -------------------------------------------------------------------- *) +let rec etyargs_tvar_fv (etyargs : etyarg list) = + List.fold_left + (fun fv etyarg -> Sid.union fv (etyarg_tvar_fv etyarg)) + Sid.empty etyargs + +and etyarg_tvar_fv ((ty, tcws) : etyarg) : Sid.t = + Sid.union (Tvar.fv ty) (tcws_tvar_fv tcws) + +and tcws_tvar_fv (tcws : tcwitness list) = + List.fold_left + (fun fv tcw -> Sid.union fv (tcw_tvar_fv tcw)) + Sid.empty tcws + +and tcw_tvar_fv (tcw : tcwitness) : Sid.t = + match tcw with + | TCIUni _ -> + Sid.empty + + | TCIConcrete { etyargs } -> + etyargs_tvar_fv etyargs + + | TCIAbstract { support = `Var tyvar } -> + Sid.singleton tyvar + + | TCIAbstract { support = (`Abs _) } -> + Sid.empty + (* -------------------------------------------------------------------- *) type pvar_kind = EcAst.pvar_kind @@ -317,38 +394,54 @@ let lp_bind = function List.pmap (fun (x, ty) -> omap (fun x -> (x, ty)) x) b (* -------------------------------------------------------------------- *) -type expr = EcAst.expr - +type expr = EcAst.expr type expr_node = EcAst.expr_node - type equantif = EcAst.equantif type ebinding = EcAst.ebinding type ebindings = EcAst.ebindings type closure = (EcIdent.t * ty) list * expr +(* -------------------------------------------------------------------- *) +type etyarg = EcAst.etyarg + +let etyarg_fv = EcAst.etyarg_fv +let etyargs_fv = EcAst.etyargs_fv +let etyarg_hash = EcAst.etyarg_hash +let etyarg_equal = EcAst.etyarg_equal + +(* -------------------------------------------------------------------- *) +type tcwitness = EcAst.tcwitness + +let tcw_fv = EcAst.tcw_fv +let tcw_hash = EcAst.tcw_hash +let tcw_equal = EcAst.tcw_equal + (* -------------------------------------------------------------------- *) let e_equal = EcAst.e_equal -let e_hash = EcAst.e_hash let e_compare = fun e1 e2 -> e_hash e1 - e_hash e2 let e_fv = EcAst.e_fv +let e_hash = EcAst.e_hash let e_ty e = e.e_ty (* -------------------------------------------------------------------- *) let lp_fv = EcAst.lp_fv - let pv_fv = EcAst.pv_fv (* -------------------------------------------------------------------- *) let eqt_equal = EcAst.eqt_equal -(* -------------------------------------------------------------------- *) - let e_tt = mk_expr (Eop (EcCoreLib.CI_Unit.p_tt, [])) tunit let e_int = fun i -> mk_expr (Eint i) tint let e_local = fun x ty -> mk_expr (Elocal x) ty let e_var = fun x ty -> mk_expr (Evar x) ty -let e_op = fun x targs ty -> mk_expr (Eop (x, targs)) ty + +let e_op_tc x targs ty = + mk_expr (Eop (x, targs)) ty + +let e_op x targs ty = + e_op_tc x (List.map (fun ty -> (ty, [])) targs) ty + let e_let = fun pt e1 e2 -> mk_expr (Elet (pt, e1, e2)) e2.e_ty let e_tuple = fun es -> match es with @@ -366,13 +459,6 @@ let e_proj_simpl e i ty = | _ -> e_proj e i ty let e_quantif q b e = - if List.is_empty b then e else - - let b, e = - match e.e_node with - | Equant (q', b', e) when eqt_equal q q' -> (b@b', e) - | _ -> b, e in - let ty = match q with | `ELambda -> toarrow (List.map snd b) e.e_ty @@ -385,11 +471,7 @@ let e_exists b e = e_quantif `EExists b e let e_lam b e = e_quantif `ELambda b e let e_app x args ty = - if args = [] then x - else - match x.e_node with - | Eapp(x', args') -> mk_expr (Eapp (x', (args'@args))) ty - | _ -> mk_expr (Eapp (x, args)) ty + mk_expr (Eapp (x, args)) ty let e_app_op ?(tyargs=[]) op args ty = e_app (e_op op tyargs (toarrow (List.map e_ty args) ty)) args ty @@ -448,54 +530,33 @@ let e_oget (e : expr) (ty : ty) : expr = e_app op [e] ty (* -------------------------------------------------------------------- *) -let e_map fty fe e = +let e_map (ft : ty -> ty) (fe : expr -> expr) (e : expr) : expr = match e.e_node with - | Eint _ | Elocal _ | Evar _ -> e - - | Eop (p, tys) -> - let tys' = List.Smart.map fty tys in - let ty' = fty e.e_ty in - e_op p tys' ty' + | Eint _ -> e + | Elocal _ -> e + | Evar _ -> e + | Eop _ -> e | Eapp (e1, args) -> - let e1' = fe e1 in - let args' = List.Smart.map fe args in - let ty' = fty e.e_ty in - e_app e1' args' ty' + e_app (fe e1) (List.Smart.map fe args) (ft e.e_ty) | Elet (lp, e1, e2) -> - let e1' = fe e1 in - let e2' = fe e2 in - e_let lp e1' e2' + e_let lp (fe e1) (fe e2) | Etuple le -> - let le' = List.Smart.map fe le in - e_tuple le' + e_tuple (List.Smart.map fe le) | Eproj (e1, i) -> - let e' = fe e1 in - let ty = fty e.e_ty in - e_proj e' i ty + e_proj (fe e1) i (ft e.e_ty) | Eif (e1, e2, e3) -> - let e1' = fe e1 in - let e2' = fe e2 in - let e3' = fe e3 in - e_if e1' e2' e3' + e_if (fe e1) (fe e2) (fe e3) - | Ematch (b, es, ty) -> - let ty' = fty ty in - let b' = fe b in - let es' = List.Smart.map fe es in - e_match b' es' ty' + | Ematch (e, bs, ty) -> + e_match (fe e) (List.Smart.map fe bs) (ft ty) | Equant (q, b, bd) -> - let dop (x, ty as xty) = - let ty' = fty ty in - if ty == ty' then xty else (x, ty') in - let b' = List.Smart.map dop b in - let bd' = fe bd in - e_quantif q b' bd' + e_quantif q b (fe bd) let e_fold (fe : 'a -> expr -> 'a) (state : 'a) (e : expr) = match e.e_node with @@ -514,6 +575,7 @@ let e_fold (fe : 'a -> expr -> 'a) (state : 'a) (e : expr) = let e_iter (fe : expr -> unit) (e : expr) = e_fold (fun () e -> fe e) () e +(* -------------------------------------------------------------------- *) module MSHe = EcMaps.MakeMSH(struct type t = expr let tag e = e.e_tag end) module Me = MSHe.M module Se = MSHe.S @@ -564,3 +626,4 @@ let split_args e = match e.e_node with | Eapp (e, args) -> (e, args) | _ -> (e, []) + \ No newline at end of file diff --git a/src/ecTypes.mli b/src/ecTypes.mli index cd45abdd1e..c951ea906e 100644 --- a/src/ecTypes.mli +++ b/src/ecTypes.mli @@ -1,4 +1,6 @@ (* -------------------------------------------------------------------- *) + +open EcAst open EcBigInt open EcMaps open EcSymbols @@ -27,13 +29,14 @@ val dump_ty : ty -> string val ty_equal : ty -> ty -> bool val ty_hash : ty -> int -val tuni : EcUid.uid -> ty -val tvar : EcIdent.t -> ty -val ttuple : ty list -> ty -val tconstr : EcPath.path -> ty list -> ty -val tfun : ty -> ty -> ty -val tglob : EcIdent.t -> ty -val tpred : ty -> ty +val tuni : tyuni -> ty +val tvar : EcIdent.t -> ty +val ttuple : ty list -> ty +val tconstr : EcPath.path -> ty list -> ty +val tconstr_tc : EcPath.path -> EcAst.etyarg list -> ty +val tfun : ty -> ty -> ty +val tglob : EcIdent.t -> ty +val tpred : ty -> ty val ty_fv_and_tvar : ty -> int Mid.t @@ -65,20 +68,30 @@ exception FoundUnivar val ty_check_uni : ty -> unit (* -------------------------------------------------------------------- *) - module Tvar : sig - val fv : ty -> Sid.t + val fv : ty -> Sid.t end (* -------------------------------------------------------------------- *) (* [map f t] applies [f] on strict subterms of [t] (not recursive) *) val ty_map : (ty -> ty) -> ty -> ty +val etyarg_map : (ty -> ty) -> etyarg -> etyarg +val tcw_map : (ty -> ty) -> tcwitness -> tcwitness (* [sub_exists f t] true if one of the strict-subterm of [t] valid [f] *) val ty_sub_exists : (ty -> bool) -> ty -> bool +val etyarg_sub_exists : (ty -> bool) -> etyarg -> bool +val tcw_sub_exists : (ty -> bool) -> tcwitness -> bool +(* -------------------------------------------------------------------- *) val ty_fold : ('a -> ty -> 'a) -> 'a -> ty -> 'a +val etyarg_fold : ('a -> ty -> 'a) -> 'a -> etyarg -> 'a +val tcw_fold : ('a -> ty -> 'a) -> 'a -> tcwitness -> 'a + +(* -------------------------------------------------------------------- *) val ty_iter : (ty -> unit) -> ty -> unit +val etyarg_iter : (ty -> unit) -> etyarg -> unit +val tcw_iter : (ty -> unit) -> tcwitness -> unit val var_mem : ?check_glob:bool -> EcIdent.t -> ty -> bool @@ -161,6 +174,27 @@ type closure = (EcIdent.t * ty) list * expr (* -------------------------------------------------------------------- *) val eqt_equal : equantif -> equantif -> bool +(* -------------------------------------------------------------------- *) +type etyarg = EcAst.etyarg + +val etyarg_fv : etyarg -> int Mid.t +val etyargs_fv : etyarg list -> int Mid.t +val etyarg_hash : etyarg -> int +val etyarg_equal : etyarg -> etyarg -> bool + +(* -------------------------------------------------------------------- *) +type tcwitness = EcAst.tcwitness + +val tcw_fv : tcwitness -> int Mid.t +val tcw_hash : tcwitness -> int +val tcw_equal : tcwitness -> tcwitness -> bool + +(* -------------------------------------------------------------------- *) +val etyargs_tvar_fv : etyarg list -> Sid.t +val etyarg_tvar_fv : etyarg -> Sid.t +val tcws_tvar_fv : tcwitness list -> Sid.t +val tcw_tvar_fv : tcwitness -> Sid.t + (* -------------------------------------------------------------------- *) val e_equal : expr -> expr -> bool val e_compare : expr -> expr -> int @@ -174,6 +208,7 @@ val e_int : zint -> expr val e_decimal : zint * (int * zint) -> expr val e_local : EcIdent.t -> ty -> expr val e_var : prog_var -> ty -> expr +val e_op_tc : EcPath.path -> etyarg list -> ty -> expr val e_op : EcPath.path -> ty list -> ty -> expr val e_app : expr -> expr list -> ty -> expr val e_not : expr -> expr @@ -212,7 +247,7 @@ val split_args : expr -> expr * expr list (* -------------------------------------------------------------------- *) val e_map : - (ty -> ty ) (* 1-subtype op. *) + (ty -> ty) (* 1-type op. *) -> (expr -> expr) (* 1-subexpr op. *) -> expr -> expr @@ -221,5 +256,3 @@ val e_fold : ('state -> expr -> 'state) -> 'state -> expr -> 'state val e_iter : (expr -> unit) -> expr -> unit - -(* -------------------------------------------------------------------- *) diff --git a/src/ecTyping.ml b/src/ecTyping.ml index 3a4c27a778..9e2de616ad 100644 --- a/src/ecTyping.ml +++ b/src/ecTyping.ml @@ -22,7 +22,7 @@ module NormMp = EcEnv.NormMp (* -------------------------------------------------------------------- *) type opmatch = [ - | `Op of EcPath.path * EcTypes.ty list + | `Op of EcPath.path * EcTypes.etyarg list | `Lc of EcIdent.t | `Var of EcTypes.prog_var | `Proj of EcTypes.prog_var * EcMemory.proj_arg @@ -127,7 +127,7 @@ type goal_shape_error = type tyerror = | UniVarNotAllowed -| FreeTypeVariables +| FreeUniVariables of EcUnify.uniflags | TypeVarNotAllowed | OnlyMonoTypeAllowed of symbol option | NoConcreteAnonParams @@ -154,6 +154,7 @@ type tyerror = | NonUnitFunWithoutReturn | TypeMismatch of (ty * ty) * (ty * ty) | TypeClassMismatch +| TypeClassAmbiguous of typeclass * EcPath.path list | TypeModMismatch of mpath * module_type * tymod_cnv_failure | NotAFunction | NotAnInductive @@ -187,6 +188,8 @@ type tyerror = | ModuleNotAbstract of symbol | ProcedureUnbounded of symbol * symbol | LvMapOnNonAssign +| TCArgsCountMismatch of qsymbol * ty_params * ty list +| CannotInferTC of ty * typeclass | NoDefaultMemRestr | ProcAssign of qsymbol | PositiveShouldBeBeforeNegative @@ -211,13 +214,18 @@ module UE = EcUnify.UniEnv let unify_or_fail (env : EcEnv.env) ue loc ~expct:ty1 ty2 = try EcUnify.unify env ue ty1 ty2 - with EcUnify.UnificationFailure pb -> + with + | EcUnify.AmbiguousTcInstance (tc, paths) -> + tyerror loc env (TypeClassAmbiguous (tc, paths)) + | EcUnify.UnificationFailure pb -> match pb with | `TyUni (t1, t2)-> let uidmap = UE.assubst ue in let tyinst = ty_subst (Tuni.subst uidmap) in tyerror loc env (TypeMismatch ((tyinst ty1, tyinst ty2), (tyinst t1, tyinst t2))) + | `TcCtt _ | `TcTw _ -> (* FIXME: proper error message *) + tyerror loc env TypeClassMismatch (* -------------------------------------------------------------------- *) let add_glob (m:Sx.t) (x:prog_var) : Sx.t = @@ -341,7 +349,7 @@ module OpSelect = struct type opsel = [ | `Pv of EcMemory.memory option * pvsel - | `Op of (EcPath.path * ty list) + | `Op of (EcPath.path * etyarg list) | `Lc of EcIdent.ident | `Nt of EcUnify.sbody ] @@ -369,7 +377,7 @@ let gen_select_op let fpv me (pv, ty, ue) : OpSelect.gopsel = (`Pv (me, pv), ty, ue, (pv :> opmatch)) - and fop (op, ty, ue, bd) : OpSelect.gopsel= + and fop ((op : path * etyarg list), ty, ue, bd) : OpSelect.gopsel = match bd with | None -> (`Op op, ty, ue, (`Op op :> opmatch)) | Some bd -> (`Nt bd, ty, ue, (`Op op :> opmatch)) @@ -389,12 +397,70 @@ let gen_select_op and by_current ((p, _), _, _, _) = EcPath.isprefix ~prefix:(oget (EcPath.prefix p)) ~path:(EcEnv.root env) - and by_tc ((p, _), _, _, _) = - match oget (EcEnv.Op.by_path_opt p env) with - | { op_kind = OB_oper (Some OP_TC) } -> false - | _ -> true - in + (* Subsumption filter on the candidate list: drop a TC-op candidate + when, for the resolved carrier, either + (a) [tc_reduce] succeeds and yields a head op already among the + non-TC candidates (the TC is just an indirection to it), or + (b) the carrier is concrete but no instance applies (no chance of + the TC ever firing on this carrier), and a non-TC candidate + exists. + When the carrier is still a univar (or no non-TC candidate exists), + keep the TC op. The typer's existing [MultipleOpMatch] retry then + re-types with a fresh argument univar, and downstream context + disambiguates via expected-type unification. *) + let drop_subsumed_tc ops = + let is_tc_op p = + match (EcEnv.Op.by_path_opt p env) with + | Some { op_kind = OB_oper (Some (OP_TC _)) } -> true + | _ -> false in + let concrete_paths = + List.filter_map + (fun ((p, _), _, _, _) -> if is_tc_op p then None else Some p) + ops in + if concrete_paths = [] then ops + else + let carrier_is_concrete (etyargs : etyarg list) (subue : EcUnify.unienv) = + (* "Concrete" = a [Tconstr] whose declaration is NOT an abstract + type-with-TC. Class-declaration self-types and section-bound + abstract types are also [Tconstr] but should still be + treated as abstract here so TC-op candidates aren't pruned. *) + match List.rev etyargs with + | [] -> false + | (ty, _) :: _ -> + let ty = ty_subst (Tuni.subst (EcUnify.UniEnv.assubst subue)) ty in + match ty.ty_node with + | Tconstr (p, _) -> begin + match EcEnv.Ty.by_path_opt p env with + (* [Abstract (_ :: _)]: section-bound or class-self type + with TC bounds — TC ops on it may still be viable + via the bounds. [Abstract []]: primitive (e.g. [int]) — + TC viability requires a registered instance, treat as + concrete here so [drop_subsumed_tc] can dedup TC ops + against non-TC candidates with the same effective + head. *) + | Some { tyd_type = `Abstract (_ :: _) } -> false + | _ -> true + end + | _ -> false in + List.filter (fun ((p, etyargs), _, subue, _) -> + if not (is_tc_op p) then true + else + match EcEnv.Op.tc_reduce env p etyargs with + | red -> begin + let red_head = + match red.f_node with + | Fop (p', _) -> Some p' + | Fapp ({ f_node = Fop (p', _) }, _) -> Some p' + | _ -> None in + match red_head with + | None -> true + | Some p' -> + not (List.exists (EcPath.p_equal p') concrete_paths) + end + | exception EcEnv.NotReducible -> + not (carrier_is_concrete etyargs subue) + ) ops in let locals () : OpSelect.gopsel list = if Option.is_none tvi then @@ -404,11 +470,131 @@ let gen_select_op |> Option.to_list else [] in + (* Drop notation/abbrev candidates ([OB_nott]) when a TC-op + candidate sharing the same basename is also present. The + [TcMonoid] family ships generic notation abbrevs like + [abbrev ( * ) ['a <: mulmonoid] (x y) = (+)<:'a> x y] that, when + applied at a [comring] carrier, expand to exactly the same + [Fop] as comring's own [( * )] TC operator. The two are + interchangeable but [select_op] returns both, leaving the user + with a [MultipleOpMatch] error in any external file that + imports the algebra hierarchy. Inside the defining file the + [by_current] filter drops the abbrev (different prefix), but + across files we need a structural rule. *) + let drop_shadowed_notation ops = + let has_tc_op_with_name n = + List.exists (fun ((p, _), _, _, _) -> + match EcEnv.Op.by_path_opt p env with + | Some { op_kind = OB_oper (Some (OP_TC _)) } -> + EcPath.basename p = n + | _ -> false) ops in + List.filter (fun ((p, _), _, _, _) -> + match EcEnv.Op.by_path_opt p env with + | Some { op_kind = OB_nott _ } -> + not (has_tc_op_with_name (EcPath.basename p)) + | _ -> true) ops in + + (* [drop_subsumed_tc] classifies candidates by the [op_kind] of their + declared path — but after typing, abbrev candidates are inlined and + their bodies' heads are what actually appear in the elaborated + term. So an abbrev whose body is itself a TC-op invocation + (e.g. [TcMonoid.( * ) ['a <: mulmonoid] (x y) = (+)<:'a> x y]) + escapes [drop_subsumed_tc]'s filter even though, post-inline, it's + a TC-op-headed term that may reduce to the same head as another + candidate. + + This pass closes that gap: for each candidate, compute its + post-inline body head; collect the non-TC heads as + [concrete_heads]; then drop any candidate whose post-inline head + is a TC op that [tc_reduce]s to a head already in [concrete_heads]. + Mirror image of [drop_subsumed_tc] but operating on body heads + rather than declared op_kind, catching the abbrev-to-TC-op case + that the pre-inline classification misses. *) + let drop_subsumed_by_post_inline_head ops = + let is_tc_op p = + match EcEnv.Op.by_path_opt p env with + | Some { op_kind = OB_oper (Some (OP_TC _)) } -> true + | _ -> false in + let body_head ((path, etyargs), _, _, bd) = + match bd with + | None -> Some (path, etyargs) + | Some bd_lazy -> + let _, body = Lazy.force bd_lazy in + let head, _ = EcTypes.destr_app body in + (match head.e_node with + | Eop (p, tys) -> Some (p, tys) + | _ -> None) in + let concrete_heads = + List.filter_map (fun cand -> + match body_head cand with + | Some (p, _) when not (is_tc_op p) -> Some p + | _ -> None) ops in + if concrete_heads = [] then ops + else + List.filter (fun cand -> + match body_head cand with + | Some (p, etyargs) when is_tc_op p -> begin + match EcEnv.Op.tc_reduce env p etyargs with + | red -> + let red_head = + match red.f_node with + | Fop (p', _) -> Some p' + | Fapp ({ f_node = Fop (p', _) }, _) -> Some p' + | _ -> None in + (match red_head with + | None -> true + | Some p' -> not (List.exists (EcPath.p_equal p') concrete_heads)) + | exception EcEnv.NotReducible -> true + end + | _ -> true) ops in + + (* Drop a TC-bounded notation candidate (an abbrev whose tparams have + non-empty TC bounds, e.g. [TcRing.(-) ['a <: addgroup] (x y) = …]) + when a same-basename candidate with no TC-bounded tparams (e.g. the + monomorphic [Int.(-)] abbrev) is also present. The TC-bounded form, + when instantiated at a carrier that also has a non-TC alternative, + unfolds to the same operator, so [select_op] returning both leaves + a spurious [MultipleOpMatch]. Mirror image of [drop_subsumed_tc] + for the abbrev side. *) + let drop_tc_bounded_notation ops = + let is_tc_bounded_nott p = + match EcEnv.Op.by_path_opt p env with + | Some { op_kind = OB_nott _; op_tparams = tparams } -> + List.exists (fun (_, tcs) -> tcs <> []) tparams + | _ -> false in + let has_unbounded_with_name n = + List.exists (fun ((p, _), _, _, _) -> + EcPath.basename p = n + && match EcEnv.Op.by_path_opt p env with + | Some { op_tparams = tparams } -> + not (List.exists (fun (_, tcs) -> tcs <> []) tparams) + | None -> false) ops in + List.filter (fun ((p, _), _, _, _) -> + if is_tc_bounded_nott p + then not (has_unbounded_with_name (EcPath.basename p)) + else true) ops in + let ops () : OpSelect.gopsel list = - let ops = EcUnify.select_op ~filter:ue_filter tvi env name ue psig in + let ops = EcUnify.select_op ~filter:ue_filter ?retty:(snd psig) tvi env name ue (fst psig) in let ops = opsc |> ofold (fun opsc -> List.mbfilter (by_scope opsc)) ops in let ops = match List.mbfilter by_current ops with [] -> ops | ops -> ops in - let ops = match List.mbfilter by_tc ops with [] -> ops | ops -> ops in + (* [drop_subsumed_tc] runs first because it can ELIMINATE TC + candidates that won't apply (concrete carrier with no + registered instance). Then [drop_shadowed_notation] only + fires when a TC op is actually viable, leaving abbrevs alone + in the [Int.( <= )] / int-args case. *) + let ops = + let pruned = drop_subsumed_tc ops in + if pruned = [] then ops else pruned in + let ops = + let pruned = drop_shadowed_notation ops in + if pruned = [] then ops else pruned in + let ops = + let pruned = drop_tc_bounded_notation ops in + if pruned = [] then ops else pruned in + let ops = + let pruned = drop_subsumed_by_post_inline_head ops in + if pruned = [] then ops else pruned in (List.map fop ops) and pvs () : OpSelect.gopsel list = @@ -447,7 +633,7 @@ let select_form_op env mode ~forcepv opsc name ue tvi psig = (* -------------------------------------------------------------------- *) let select_proj env opsc name ue tvi recty = let filter = (fun _ op -> EcDecl.is_proj op) in - let ops = EcUnify.select_op ~filter tvi env name ue ([recty], None) in + let ops = EcUnify.select_op ~filter tvi env name ue [recty] in let ops = List.map (fun (p, ty, ue, _) -> (p, ty, ue)) ops in match ops, opsc with @@ -483,26 +669,6 @@ let tp_uni = { tp_uni = true ; tp_tvar = false; } (* params/local vars. *) (* -------------------------------------------------------------------- *) type ismap = (instr list) Mstr.t -(* -------------------------------------------------------------------- *) -let transtcs (env : EcEnv.env) tcs = - let for1 tc = - match EcEnv.TypeClass.lookup_opt (unloc tc) env with - | None -> tyerror tc.pl_loc env (UnknownTypeClass (unloc tc)) - | Some (p, _) -> p (* FIXME: TC HOOK *) - in - Sp.of_list (List.map for1 tcs) - -(* -------------------------------------------------------------------- *) -let transtyvars (env : EcEnv.env) (loc, tparams) = - let tparams = tparams |> omap - (fun tparams -> - let for1 ({ pl_desc = x }) = (EcIdent.create x) in - if not (List.is_unique (List.map unloc tparams)) then - tyerror loc env DuplicatedTyVar; - List.map for1 tparams) - in - EcUnify.UniEnv.create tparams - (* -------------------------------------------------------------------- *) exception TymodCnvFailure of tymod_cnv_failure @@ -991,7 +1157,7 @@ let trans_msymbol env msymb = (m,mt) (* -------------------------------------------------------------------- *) -let rec transty (tp : typolicy) (env : EcEnv.env) ue ty = +let rec transty (tp : typolicy) (env : EcEnv.env) (ue : EcUnify.unienv) (ty : pty) : ty = match ty.pl_desc with | PTunivar -> if tp.tp_uni @@ -1050,6 +1216,46 @@ let transty_for_decl env ty = let ue = UE.create (Some []) in transty tp_nothing env ue ty +(* -------------------------------------------------------------------- *) +let transtc (env : EcEnv.env) ue ((tc_name, args) : ptcparam) : typeclass = + match EcEnv.TypeClass.lookup_opt (unloc tc_name) env with + | None -> + tyerror (loc tc_name) env (UnknownTypeClass (unloc tc_name)) + + | Some (p, decl) -> + let args = List.map (transty tp_tydecl env ue) args in + + if List.length decl.tc_tparams <> List.length args then begin + tyerror (loc tc_name) env + (TCArgsCountMismatch (unloc tc_name, decl.tc_tparams, args)); + end; + + let tvi = EcUnify.UniEnv.opentvi ue decl.tc_tparams None in + + List.iter2 + (fun (ty, _) aty -> unify_or_fail env ue (loc tc_name) ~expct:ty aty) + tvi.args args; + + { tc_name = p; tc_args = tvi.args; } + +(* -------------------------------------------------------------------- *) +let transtyvars (env : EcEnv.env) (loc, (tparams : ptyparams option)) = + match tparams with + | None -> + UE.create None + + | Some tparams -> + let ue = UE.create (Some []) in + + let for1 ({ pl_desc = x }, tc) = + let x = EcIdent.create x in + let tc = List.map (transtc env ue) tc in + UE.push (x, tc) ue in + if not (List.is_unique (List.map (fst |- unloc) tparams)) then + tyerror loc env DuplicatedTyVar; + List.iter for1 tparams; + ue + (* -------------------------------------------------------------------- *) let transpattern1 env ue (p : EcParsetree.plpattern) = match p.pl_desc with @@ -1075,7 +1281,7 @@ let transpattern1 env ue (p : EcParsetree.plpattern) = let fields = let for1 (name, v) = let filter = fun _ op -> EcDecl.is_proj op in - let fds = EcUnify.select_op ~filter None env (unloc name) ue ([], None) in + let fds = EcUnify.select_op ~filter None env (unloc name) ue [] in match List.ohead fds with | None -> let exn = UnknownRecFieldName (unloc name) in @@ -1099,8 +1305,9 @@ let transpattern1 env ue (p : EcParsetree.plpattern) = let recty = oget (EcEnv.Ty.by_path_opt recp env) in let rec_ = snd (oget (EcDecl.tydecl_as_record recty)) in - let reccty = tconstr recp (List.map tvar recty.tyd_params) in - let reccty, rectvi = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in + let reccty = tconstr_tc recp (etyargs_of_tparams recty.tyd_params) in + let reccty, recopnd = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in + let fields = List.fold_left (fun map (((_, idx), _, _) as field) -> @@ -1120,8 +1327,9 @@ let transpattern1 env ue (p : EcParsetree.plpattern) = let pty = EcUnify.UniEnv.fresh ue in let fty = snd (List.nth rec_ i) in let fty, _ = - EcUnify.UniEnv.openty ue recty.tyd_params - (Some (EcUnify.TVIunamed rectvi)) fty + EcUnify.UniEnv.openty + ue recty.tyd_params + (Some (EcUnify.tvi_unamed recopnd.args)) fty in (try EcUnify.unify env ue pty fty with EcUnify.UnificationFailure _ -> assert false); @@ -1154,7 +1362,9 @@ let transpattern env ue (p : EcParsetree.plpattern) = let transtvi env ue tvi = match tvi.pl_desc with | TVIunamed lt -> - EcUnify.TVIunamed (List.map (transty tp_relax env ue) lt) + let tys = List.map (transty tp_relax env ue) lt in + let tvi = List.map (fun ty -> (Some ty, None)) tys in + EcUnify.TVIunamed tvi | TVInamed lst -> let add locals (s, t) = @@ -1163,8 +1373,9 @@ let transtvi env ue tvi = (s, transty tp_relax env ue t) :: locals in - let lst = List.fold_left add [] lst in - EcUnify.TVInamed (List.rev_map (fun (s,t) -> unloc s, t) lst) + let tvi = List.fold_left add [] lst in + let tvi = List.map (snd_map (fun ty -> (Some ty, None))) tvi in + EcUnify.TVInamed (List.rev_map (fun (s, t) -> unloc s, t) tvi) let rec destr_tfun env ue tf = match tf.ty_node with @@ -1215,7 +1426,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = let for1 rf = let filter = fun _ op -> EcDecl.is_proj op in let tvi = rf.rf_tvi |> omap (transtvi env ue) in - let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue ([], None) in + let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue [] in match List.ohead fds with | None -> let exn = UnknownRecFieldName (unloc rf.rf_name) in @@ -1239,9 +1450,8 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = let recty = oget (EcEnv.Ty.by_path_opt recp env) in let rec_ = snd (oget (EcDecl.tydecl_as_record recty)) in - let reccty = tconstr recp (List.map tvar recty.tyd_params) in - let reccty, rtvi = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in - let tysopn = Tvar.init recty.tyd_params rtvi in + let reccty = tconstr_tc recp (EcDecl.etyargs_of_tparams recty.tyd_params) in + let reccty, ropnd = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in let fields = List.fold_left @@ -1270,7 +1480,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = | None -> match dflrec with | None -> tyerror loc env (MissingRecField name) - | Some _ -> `Dfl (Tvar.subst tysopn rty, name) + | Some _ -> `Dfl (Tvar.subst ropnd.subst rty, name) in List.mapi (fun i (name, rty) -> get_field i name rty) rec_ in @@ -1286,7 +1496,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = | `Dfl (rty, name) -> let nm = oget (EcPath.prefix recp) in - (proj (nm, name, (rtvi, reccty), rty, oget dflrec), rty) + (proj (nm, name, (ropnd.args, reccty), rty, oget dflrec), rty) in List.map for1 fields @@ -1297,7 +1507,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = (EcPath.prefix recp) (Printf.sprintf "mk_%s" (EcPath.basename recp)) in - (ctor, fields, (rtvi, reccty)) + (ctor, fields, (ropnd.args, reccty)) (* -------------------------------------------------------------------- *) let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) = @@ -1306,7 +1516,7 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) = | PPApp ((cname, tvi), cargs) -> let filter = fun _ op -> EcDecl.is_ctor op in let tvi = tvi |> omap (transtvi env ue) in - let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in + let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in match cts with | [] -> @@ -1339,7 +1549,7 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) = EcUnify.UniEnv.restore ~src:subue ~dst:ue; let ctorty = - let tvi = Some (EcUnify.TVIunamed tvi) in + let tvi = Some (EcUnify.TVIunamed (List.map (fun (ty, w) -> (Some ty, Some (List.map (fun x -> Some x) w))) tvi)) in fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in let pty = EcUnify.UniEnv.fresh ue in @@ -1403,7 +1613,7 @@ let trans_branch_exn env ue ((pb, body) : ppattern * _) = | PPApp ((cname, tvi), cargs) -> let filter = fun _ op -> EcDecl.is_exception op in let tvi = tvi |> omap (transtvi env ue) in - let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in + let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in match cts with | [] -> @@ -1452,7 +1662,6 @@ let trans_match_exn ~loc env ue pbs = List.fold_left fill_branch Mop.empty pbs (*-------------------------------------------------------------------- *) - let var_or_proj fvar fproj pv ty = match pv with | `Var pv -> fvar pv ty @@ -1674,7 +1883,7 @@ let form_of_opselect in (f_lambda flam (Fsubst.f_subst subst body), args) | (`Op _ | `Lc _ | `Pv _) as sel -> let op = match sel with - | `Op (p, tys) -> f_op p tys ty + | `Op (p, tys) -> f_op_tc p tys ty | `Lc id -> f_local id ty | `Pv (me, pv) -> var_or_proj (fun x ty -> (f_pvar x ty (oget me)).inv) f_proj pv ty @@ -1692,7 +1901,7 @@ let form_of_opselect * - e is the index to update * - ty is the type of the value [x] *) -type lvmap = (path * ty list) * prog_var * expr * ty +type lvmap = (path * etyarg list) * prog_var * expr * ty type lVAl = | Lval of lvalue @@ -1702,7 +1911,7 @@ let i_asgn_lv (_loc : EcLocation.t) (_env : EcEnv.env) lv e = match lv with | Lval lv -> i_asgn (lv, e) | LvMap ((op,tys), x, ei, ty) -> - let op = e_op op tys (toarrow [ty; ei.e_ty; e.e_ty] ty) in + let op = e_op_tc op tys (toarrow [ty; ei.e_ty; e.e_ty] ty) in i_asgn (LvVar (x,ty), e_app op [e_var x ty; ei; e] ty) let i_rnd_lv loc env lv e = @@ -2251,7 +2460,7 @@ and transmod_body ~attop (env : EcEnv.env) x params (me:pmodule_expr) = let asgn = EcModules.lv_of_list pvs |> omap (fun lv -> let rty = ttuple (List.snd p) in let proj = EcInductive.datatype_proj_path typ cn in - let proj = e_op proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in + let proj = e_op_tc proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in let proj = e_app proj [e] (toption rty) in let proj = e_oget proj rty in i_asgn (lv, proj)) @@ -2671,7 +2880,7 @@ and fundef_add_symbol env (memenv : memenv) xtys : memenv = and fundef_check_type subst_uni env os (ty, loc) = let ty = subst_uni ty in - if not (EcUid.Suid.is_empty (Tuni.fv ty)) then + if not (TyUni.Suid.is_empty (Tuni.fv ty)) then tyerror loc env (OnlyMonoTypeAllowed os); ty @@ -2795,7 +3004,7 @@ and transinstr match (EcEnv.ty_hnorm ety env).ty_node with | Tconstr (indp, _) -> begin match EcEnv.Ty.by_path indp env with - | { tyd_type = Datatype dt } -> + | { tyd_type = `Datatype dt } -> Some (indp, dt) | _ -> None end @@ -3406,7 +3615,7 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt = match (EcEnv.ty_hnorm cfty env).ty_node with | Tconstr (indp, _) -> begin match EcEnv.Ty.by_path indp env with - | { tyd_type = Datatype dt } -> + | { tyd_type = `Datatype dt } -> Some (indp, dt) | _ -> None end @@ -3465,12 +3674,12 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt = let (ctor, fields, (rtvi, reccty)) = let proj (recp, name, (rtvi, reccty), pty, arg) = let proj = EcPath.pqname recp name in - let proj = f_op proj rtvi (tfun reccty pty) in + let proj = f_op_tc proj rtvi (tfun reccty pty) in f_app proj [arg] pty in trans_record env ue ((fun f -> let f = transf env f in (f, f.f_ty)), proj) (f.pl_loc, b, fields) in - let ctor = f_op ctor rtvi (toarrow (List.map snd fields) reccty) in + let ctor = f_op_tc ctor rtvi (toarrow (List.map snd fields) reccty) in f_app ctor (List.map fst fields) reccty | PFproj (subf, x) -> begin @@ -3488,7 +3697,7 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt = let rty = EcUnify.UniEnv.fresh ue in (try EcUnify.unify env ue (tfun subf.f_ty rty) pty with EcUnify.UnificationFailure _ -> assert false); - f_app (f_op op tvi pty) [subf] rty + f_app (f_op_tc op tvi pty) [subf] rty end | PFproji (psubf, i) -> begin @@ -3814,15 +4023,21 @@ and trans_dcodegap1 ?(memory : memory option) (env : EcEnv.env) (p : pcodegap1 d (* -------------------------------------------------------------------- *) let get_instances (tvi, bty) env = - let inst = List.pmap - (function - | (_, (`Ring _ | `Field _)) as x -> Some x - | _ -> None) - (EcEnv.TypeClass.get_instances env) in + let inst = + let filter ((_, tci) : path option * EcTheory.tcinstance) = + match tci with + | EcTheory.{ + tci_params = []; + tci_instance = (`Ring _ | `Field _) as bd + } -> Some (tci.tci_type, bd) + + | _ -> None + + in List.pmap filter (EcEnv.TcInstance.get_all env) in - List.pmap (fun ((typ, gty), cr) -> + List.pmap (fun (gty, cr) -> let ue = EcUnify.UniEnv.create (Some tvi) in - let (gty, _typ) = EcUnify.UniEnv.openty ue typ None gty in + let (gty, _) = EcUnify.UniEnv.openty ue [] None gty in try EcUnify.unify env ue bty gty; let ts = Tuni.subst (UE.close ue) in diff --git a/src/ecTyping.mli b/src/ecTyping.mli index fa2514dcbd..5bea080f06 100644 --- a/src/ecTyping.mli +++ b/src/ecTyping.mli @@ -14,7 +14,7 @@ open EcMatching.Position (* -------------------------------------------------------------------- *) type opmatch = [ - | `Op of EcPath.path * EcTypes.ty list + | `Op of EcPath.path * EcTypes.etyarg list | `Lc of EcIdent.t | `Var of EcTypes.prog_var | `Proj of EcTypes.prog_var * EcMemory.proj_arg @@ -23,7 +23,7 @@ type opmatch = [ type 'a mismatch_sets = [`Eq of 'a * 'a | `Sub of 'a ] -type 'a suboreq = [`Eq of 'a | `Sub of 'a ] +type 'a suboreq = [`Eq of 'a | `Sub of 'a ] type mismatch_funsig = | MF_targs of ty * ty (* expected, got *) @@ -120,7 +120,7 @@ type goal_shape_error = type tyerror = | UniVarNotAllowed -| FreeTypeVariables +| FreeUniVariables of EcUnify.uniflags | TypeVarNotAllowed | OnlyMonoTypeAllowed of symbol option | NoConcreteAnonParams @@ -147,6 +147,7 @@ type tyerror = | NonUnitFunWithoutReturn | TypeMismatch of (ty * ty) * (ty * ty) | TypeClassMismatch +| TypeClassAmbiguous of typeclass * EcPath.path list | TypeModMismatch of mpath * module_type * tymod_cnv_failure | NotAFunction | NotAnInductive @@ -180,6 +181,8 @@ type tyerror = | ModuleNotAbstract of symbol | ProcedureUnbounded of symbol * symbol | LvMapOnNonAssign +| TCArgsCountMismatch of qsymbol * ty_params * ty list +| CannotInferTC of ty * typeclass | NoDefaultMemRestr | ProcAssign of qsymbol | PositiveShouldBeBeforeNegative @@ -202,6 +205,9 @@ val tp_relax : typolicy val tp_nothing : typolicy (* -------------------------------------------------------------------- *) +val transtc: + env -> EcUnify.unienv -> ptcparam -> typeclass + val transtyvars: env -> (EcLocation.t * ptyparams option) -> EcUnify.unienv diff --git a/src/ecUid.ml b/src/ecUid.ml index 6e9124b62c..8b4643cfd0 100644 --- a/src/ecUid.ml +++ b/src/ecUid.ml @@ -6,37 +6,84 @@ open EcSymbols (* -------------------------------------------------------------------- *) let unique () = Oo.id (object end) +(* -------------------------------------------------------------------- *) +module type ICore = sig + type uid + + (* ------------------------------------------------------------------ *) + val unique : unit -> uid + val uid_equal : uid -> uid -> bool + val uid_compare : uid -> uid -> int + + (* ------------------------------------------------------------------ *) + module Muid : Map.S with type key = uid + module Suid : Set.S with module M = Map.MakeBase(Muid) + + (* ------------------------------------------------------------------ *) + module SMap : sig + type uidmap + + val create : unit -> uidmap + val lookup : uidmap -> symbol -> uid option + val forsym : uidmap -> symbol -> uid + val pp_uid : Format.formatter -> uid -> unit + end +end + (* -------------------------------------------------------------------- *) type uid = int -type uidmap = { - (*---*) um_tbl : (symbol, uid) Hashtbl.t; - mutable um_uid : int; -} +(* -------------------------------------------------------------------- *) +module Core : ICore with type uid := uid = struct + (* ------------------------------------------------------------------ *) + let unique () : uid = + unique () + + let uid_equal x y = x == y + let uid_compare x y = x - y + + (* ------------------------------------------------------------------ *) + module Muid = Mint + module Suid = Set.MakeOfMap(Muid) + + (* ------------------------------------------------------------------ *) + module SMap = struct + type uidmap = { + (*---*) um_tbl : (symbol, uid) Hashtbl.t; + mutable um_uid : int; + } + + let create () = + { um_tbl = Hashtbl.create 0; + um_uid = 0; } -let create () = - { um_tbl = Hashtbl.create 0; - um_uid = 0; } + let lookup (um : uidmap) (x : symbol) = + try Some (Hashtbl.find um.um_tbl x) + with Not_found -> None -let lookup (um : uidmap) (x : symbol) = - try Some (Hashtbl.find um.um_tbl x) - with Not_found -> None + let forsym (um : uidmap) (x : symbol) = + match lookup um x with + | Some uid -> uid + | None -> + let uid = um.um_uid in + um.um_uid <- um.um_uid + 1; + Hashtbl.add um.um_tbl x uid; + uid -let forsym (um : uidmap) (x : symbol) = - match lookup um x with - | Some uid -> uid - | None -> - let uid = um.um_uid in - um.um_uid <- um.um_uid + 1; - Hashtbl.add um.um_tbl x uid; - uid + let pp_uid fmt u = + Format.fprintf fmt "#%d" u + end +end (* -------------------------------------------------------------------- *) -let uid_equal x y = x == y -let uid_compare x y = x - y +module CoreGen() : ICore with type uid = private uid = struct + type nonrec uid = uid + + include Core +end -module Muid = Mint -module Suid = Set.MakeOfMap(Muid) +(* -------------------------------------------------------------------- *) +include Core (* -------------------------------------------------------------------- *) module NameGen = struct diff --git a/src/ecUid.mli b/src/ecUid.mli index 885bcbd99f..429132eef9 100644 --- a/src/ecUid.mli +++ b/src/ecUid.mli @@ -5,20 +5,37 @@ open EcSymbols (* -------------------------------------------------------------------- *) val unique : unit -> int +module type ICore = sig + type uid + + (* ------------------------------------------------------------------ *) + val unique : unit -> uid + val uid_equal : uid -> uid -> bool + val uid_compare : uid -> uid -> int + + (* ------------------------------------------------------------------ *) + module Muid : Map.S with type key = uid + module Suid : Set.S with module M = Map.MakeBase(Muid) + + (* ------------------------------------------------------------------ *) + module SMap : sig + type uidmap + + val create : unit -> uidmap + val lookup : uidmap -> symbol -> uid option + val forsym : uidmap -> symbol -> uid + val pp_uid : Format.formatter -> uid -> unit + end +end + (* -------------------------------------------------------------------- *) type uid = int -type uidmap - -val create : unit -> uidmap -val lookup : uidmap -> symbol -> uid option -val forsym : uidmap -> symbol -> uid (* -------------------------------------------------------------------- *) -val uid_equal : uid -> uid -> bool -val uid_compare : uid -> uid -> int +include ICore with type uid := uid -module Muid : Map.S with type key = uid -module Suid : Set.S with module M = Map.MakeBase(Muid) +(* -------------------------------------------------------------------- *) +module CoreGen() : ICore with type uid = private uid (* -------------------------------------------------------------------- *) module NameGen : sig diff --git a/src/ecUnify.ml b/src/ecUnify.ml index 45cc667535..a978a44155 100644 --- a/src/ecUnify.ml +++ b/src/ecUnify.ml @@ -3,204 +3,813 @@ open EcSymbols open EcIdent open EcMaps open EcUtils -open EcUid open EcAst open EcTypes open EcCoreSubst open EcDecl +open EcTheory module Sp = EcPath.Sp -module TC = EcTypeClass -(* -------------------------------------------------------------------- *) -type pb = [ `TyUni of ty * ty ] +(* ==================================================================== *) +type problem = [ + | `TyUni of ty * ty + | `TcTw of tcwitness * tcwitness + | `TcCtt of tcuni * ty * typeclass +] -exception UnificationFailure of pb -exception UninstantiateUni +(* ==================================================================== *) +type uniflags = { tyvars: bool; tcvars: bool; } -(* -------------------------------------------------------------------- *) -module UFArgs = struct - module I = struct - type t = uid +exception UnificationFailure of problem +exception UninstanciateUni of uniflags +exception AmbiguousTcInstance of typeclass * EcPath.path list - let equal = uid_equal - let compare = uid_compare - end +(* ==================================================================== *) +module Unify = struct + module UFArgs = struct + module I = struct + type t = tyuni + + let equal = TyUni.uid_equal + let compare = TyUni.uid_compare + end + + module D = struct + type data = ty option + type effects = problem list - module D = struct - type data = ty option - type effects = pb list + let default : data = + None - let default : data = - None + let isvoid (x : data) = + Option.is_none x - let isvoid (x : data) = - Option.is_none x + let noeffects : effects = [] - let noeffects : effects = - [] + let union (ty1 : data) (ty2 : data) : data * effects = + let ty, cts = + match ty1, ty2 with + | None, None -> + (None, []) + | Some ty1, Some ty2 -> + Some ty1, [(ty1, ty2)] - let union (d1 : data) (d2 : data) = - match d1, d2 with - | None, None -> - (None, []) + | None, Some ty | Some ty, None -> + Some ty, [] in - | Some ty1, Some ty2 -> - (Some ty1, [`TyUni (ty1, ty2)]) + let cts = List.map (fun x -> `TyUni x) cts in - | None , Some ty - | Some ty, None -> - (Some ty, []) + ty, (cts :> effects) + end end -end -module UF = EcUFind.Make(UFArgs.I)(UFArgs.D) + (* ------------------------------------------------------------------ *) + module UF = EcUFind.Make(UFArgs.I)(UFArgs.D) + + (* ------------------------------------------------------------------ *) + type ucore = { + uf : UF.t; + tvtc : typeclass list Mid.t; + tcenv : tcenv; + } + + and tcenv = { + (* Map from UID to TC problems. The optional [symbol] is the + op name when the problem was created at op-typing site, used + to disambiguate parent-DAG paths whose cumulative renaming + would clobber that name. *) + problems : (ty * typeclass * EcSymbols.symbol option) TcUni.Muid.t; + + (* Map from univars to TC problems that depend on them. *) + byunivar : TcUni.Suid.t TyUni.Muid.t; + + (* Map from problems UID to type-class instance witness *) + resolution : tcwitness TcUni.Muid.t + } + + (* ------------------------------------------------------------------ *) + let tcenv_empty : tcenv = + { problems = TcUni.Muid.empty + ; byunivar = TyUni.Muid.empty + ; resolution = TcUni.Muid.empty } + + (* ------------------------------------------------------------------ *) + let tcenv_closed (tcenv : tcenv) : bool = + TcUni.Muid.cardinal tcenv.resolution + = TcUni.Muid.cardinal tcenv.problems + + (* ------------------------------------------------------------------ *) + let create_tcproblem + ?(op_name : EcSymbols.symbol option) + (tcenv : tcenv) + (ty : ty) + (tcw : typeclass * tcwitness option) + : tcenv * tcwitness + = + let tc, tw = tcw in + let uid = TcUni.unique () in + let deps = Tuni.univars ty in + + let tcenv = { + problems = TcUni.Muid.add uid (ty, tc, op_name) tcenv.problems; + byunivar = TyUni.Suid.fold (fun duni byunivar -> + TyUni.Muid.change (fun pbs -> + Some (TcUni.Suid.add uid (Option.value ~default:TcUni.Suid.empty pbs)) + ) duni byunivar + ) deps tcenv.byunivar; + resolution = + ofold + (fun tw map -> TcUni.Muid.add uid tw map) + tcenv.resolution tw; + } in + + tcenv, TCIUni (uid, []) + + (* ------------------------------------------------------------------ *) + let initial_ucore ?(tvtc = Mid.empty) () : ucore = + { uf = UF.initial; tcenv = tcenv_empty; tvtc; } + + (* -------------------------------------------------------------------- *) + type closed = { tyuni : ty -> ty; tcuni : tcwitness -> tcwitness; } + + (* -------------------------------------------------------------------- *) + let close (uc : ucore) : closed = + let tymap = Hint.create 0 in + let tcmap = Hint.create 0 in + + let rec doit_ty t = + match t.ty_node with + | Tunivar i -> begin + match Hint.find_opt tymap (i :> int) with + | Some t -> t + | None -> begin + let t = + match UF.data i uc.uf with + | None -> tuni (UF.find i uc.uf) + | Some t -> doit_ty t + in + Hint.add tymap (i :> int) t; t + end + end + + | _ -> ty_map doit_ty t + + and doit_tc (tw : tcwitness) = + match tw with + | TCIUni (uid, lift) -> begin + match Hint.find_opt tcmap (uid :> int) with + | Some tw -> bump_lift lift tw + | None -> + let resolved = + match TcUni.Muid.find_opt uid uc.tcenv.resolution with + | None -> TCIUni (uid, []) + | Some (TCIUni (uid', _)) when TcUni.uid_equal uid uid' -> TCIUni (uid, []) + | Some tw -> doit_tc tw + in + Hint.add tcmap (uid :> int) resolved; + bump_lift lift resolved + end + + | TCIConcrete ({ etyargs; _ } as c) -> + let etyargs = + List.map + (fun (ty, tws) -> (doit_ty ty, List.map doit_tc tws)) + etyargs + in TCIConcrete { c with etyargs } + + | TCIAbstract { support = (`Var _ | `Abs _) } -> + tw + + in { tyuni = doit_ty; tcuni = doit_tc; } + + (* ------------------------------------------------------------------ *) + let subst_of_uf (uc : ucore) : ty TyUni.Muid.t = + let close = close uc in + + let dereference_tyuni (uid : tyuni) = + match close.tyuni (tuni uid) with + | { ty_node = Tunivar uid' } when TyUni.uid_equal uid uid' -> None + | ty -> Some ty in + + let bindings = + List.filter_map (fun uid -> + Option.map (fun ty -> (uid, ty)) (dereference_tyuni uid) + ) (UF.domain uc.uf) in + TyUni.Muid.of_list bindings + + (* -------------------------------------------------------------------- *) + let check_closed (uc : ucore) = + let tyvars = not (UF.closed uc.uf) in + let tcvars = not (tcenv_closed uc.tcenv) in + + if tyvars || tcvars then + raise (UninstanciateUni { tyvars; tcvars }) + + (* ------------------------------------------------------------------ *) + let fresh + ?(op_name : EcSymbols.symbol option) + ?(tcs : (typeclass * tcwitness option) list option) + ?(ty : ty option) + ({ uf; tcenv } as uc : ucore) + = + let uid = TyUni.unique () in -(* -------------------------------------------------------------------- *) -module UnifyCore = struct - let fresh ?ty uf = - let uid = EcUid.unique () in let uf = match ty with | Some { ty_node = Tunivar id } -> let uf = UF.set uid None uf in - fst (UF.union uid id uf) - | None | Some _ -> UF.set uid ty uf - in - (uf, tuni uid) -end + let ty, effects = UF.union uid id uf in + assert (List.is_empty effects); + ty -(* -------------------------------------------------------------------- *) -let unify_core (env : EcEnv.env) (uf : UF.t) pb = - let failure () = raise (UnificationFailure pb) in - - let uf = ref uf in - let pb = let x = Queue.create () in Queue.push pb x; x in - - let ocheck i t = - let i = UF.find i !uf in - let map = Hint.create 0 in + | (None | Some _) as ty -> + UF.set uid ty uf + in - let rec doit t = - match t.ty_node with - | Tunivar i' -> begin - let i' = UF.find i' !uf in + let ty = Option.value ~default:(tuni uid) (UF.data uid uf) in + + let tcenv, tws = + List.fold_left_map + (fun tcenv tcw -> create_tcproblem ?op_name tcenv ty tcw) + tcenv (Option.value ~default:[] tcs) in + + ({ uc with uf; tcenv; }, (tuni uid, tws)) + + (* ------------------------------------------------------------------ *) + let unify_core (env : EcEnv.env) (uc : ucore) (pb : problem) : ucore = + let failure () = raise (UnificationFailure pb) in + + let uc = ref uc in + let pb = let x = Queue.create () in Queue.push pb x; x in + + (* Seed the queue with every unresolved TC constraint. This catches + problems whose carrier type had no univar deps at creation time + (e.g. [Tvar 'a] for a TC-constrained type parameter), which would + otherwise sit in [problems] forever, never triggered via + [byunivar] eviction. Re-pushing already-deferred problems is + idempotent: the [`TcCtt] arm just re-adds them to [byunivar]. *) + TcUni.Muid.iter (fun uid (ty, tc, _op_name) -> + if not (TcUni.Muid.mem uid (!uc).tcenv.resolution) then + Queue.push (`TcCtt (uid, ty, tc)) pb + ) (!uc).tcenv.problems; + + let ocheck i t = + let i = UF.find i (!uc).uf in + let map = Hint.create 0 in + + let rec doit t = + match t.ty_node with + | Tunivar i' -> begin + let i' = UF.find i' (!uc).uf in match i' with | _ when i = i' -> true - | _ when Hint.mem map i' -> false + | _ when Hint.mem map (i' :> int) -> false | _ -> - match UF.data i' !uf with - | None -> Hint.add map i' (); false + match UF.data i' (!uc).uf with + | None -> Hint.add map (i' :> int) (); false | Some t -> match doit t with | true -> true - | false -> Hint.add map i' (); false - end + | false -> Hint.add map (i' :> int) (); false + end - | _ -> EcTypes.ty_sub_exists doit t + | _ -> ty_sub_exists doit t + in + doit t in - doit t - in - - let setvar i t = - let (ti, effects) = UFArgs.D.union (UF.data i !uf) (Some t) in - if odfl false (ti |> omap (ocheck i)) then failure (); - List.iter (Queue.push^~ pb) effects; - uf := UF.set i ti !uf - - and getvar t = - match t.ty_node with - | Tunivar i -> odfl t (UF.data i !uf) - | _ -> t - in - - let doit () = - while not (Queue.is_empty pb) do - match Queue.pop pb with - | `TyUni (t1, t2) -> begin - let (t1, t2) = (getvar t1, getvar t2) in - - match ty_equal t1 t2 with - | true -> () - | false -> begin - match t1.ty_node, t2.ty_node with - | Tunivar id1, Tunivar id2 -> begin - if not (uid_equal id1 id2) then - let effects = reffold (swap -| UF.union id1 id2) uf in - List.iter (Queue.push^~ pb) effects - end - - | Tunivar id, _ -> setvar id t2 - | _, Tunivar id -> setvar id t1 - - | Ttuple lt1, Ttuple lt2 -> - if List.length lt1 <> List.length lt2 then failure (); - List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) lt1 lt2 - - | Tfun (t1, t2), Tfun (t1', t2') -> - Queue.push (`TyUni (t1, t1')) pb; - Queue.push (`TyUni (t2, t2')) pb + let setvar (i : tyuni) (t : ty) = + let (ti, effects) = + UFArgs.D.union (UF.data i (!uc).uf) (Some t) + in + if odfl false (ti |> omap (ocheck i)) then failure (); + List.iter (Queue.push^~ pb) effects; + + begin + match TyUni.Muid.find i (!uc).tcenv.byunivar with + | tcpbs -> + uc := { !uc with tcenv = { (!uc).tcenv with + byunivar = TyUni.Muid.remove i (!uc).tcenv.byunivar + } }; + let tcpbs = TcUni.Suid.elements tcpbs in + let tcpbs = List.map (fun uid -> + let pb = TcUni.Muid.find uid (!uc).tcenv.problems in + (uid, pb) + ) tcpbs in + List.iter (fun (uid, (ty, tc, _op)) -> Queue.push (`TcCtt (uid, ty, tc)) pb) tcpbs + + | exception Not_found -> () + end; + + uc := { !uc with uf = UF.set i ti (!uc).uf } + + and getvar t = + match t.ty_node with + | Tunivar i -> odfl t (UF.data i (!uc).uf) + | _ -> t + in - | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> + let doit () = + while not (Queue.is_empty pb) do + match Queue.pop pb with + | `TyUni (t1, t2) -> begin + let (t1, t2) = (getvar t1, getvar t2) in + + match ty_equal t1 t2 with + | true -> () + | false -> begin + match t1.ty_node, t2.ty_node with + | Tunivar id1, Tunivar id2 -> begin + if not (TyUni.uid_equal id1 id2) then begin + let effects = + reffold (fun uc -> + let uf, effects = UF.union id1 id2 uc.uf in + effects, { uc with uf } + ) uc in + List.iter (Queue.push^~ pb) effects; + (* Merge byunivar entries onto the new representative. *) + let repr = UF.find id1 (!uc).uf in + let merge id = + if not (TyUni.uid_equal id repr) then + match TyUni.Muid.find_opt id (!uc).tcenv.byunivar with + | None -> () + | Some pbs -> + uc := { !uc with tcenv = { (!uc).tcenv with byunivar = + let bv = TyUni.Muid.remove id (!uc).tcenv.byunivar in + TyUni.Muid.change (fun map -> + let map = Option.value ~default:TcUni.Suid.empty map in + Some (TcUni.Suid.union map pbs) + ) repr bv + } } + in merge id1; merge id2 + end + end + + | Tunivar id, _ -> setvar id t2 + | _, Tunivar id -> setvar id t1 + + | Ttuple lt1, Ttuple lt2 -> + if List.length lt1 <> List.length lt2 then failure (); + List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) lt1 lt2 + + | Tfun (t1, t2), Tfun (t1', t2') -> + Queue.push (`TyUni (t1, t1')) pb; + Queue.push (`TyUni (t2, t2')) pb + + | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> if List.length lt1 <> List.length lt2 then failure (); - List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) lt1 lt2 - - | Tconstr (p, lt), _ when EcEnv.Ty.defined p env -> - Queue.push (`TyUni (EcEnv.Ty.unfold p lt env, t2)) pb - - | _, Tconstr (p, lt) when EcEnv.Ty.defined p env -> - Queue.push (`TyUni (t1, EcEnv.Ty.unfold p lt env)) pb - | _, _ -> failure () - end - end - done - in - doit (); !uf + let ty1, tws1 = List.split lt1 in + let ty2, tws2 = List.split lt2 in -(* -------------------------------------------------------------------- *) -let close (uf : UF.t) = - let map = Hint.create 0 in + List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) ty1 ty2; - let rec doit t = - match t.ty_node with - | Tunivar i -> begin - match Hint.find_opt map i with - | Some t -> t - | None -> begin - let t = - match UF.data i uf with - | None -> tuni (UF.find i uf) - | Some t -> doit t - in - Hint.add map i t; t - end - end + List.iter2 (fun tw1 tw2 -> + if List.length tw1 <> List.length tw2 then failure (); + List.iter2 (fun w1 w2 -> Queue.push (`TcTw (w1, w2)) pb) tw1 tw2 + ) tws1 tws2 - | _ -> ty_map doit t - in - fun t -> doit t + | Tconstr (p, lt), _ when EcEnv.Ty.defined p env -> + Queue.push (`TyUni (EcEnv.Ty.unfold p lt env, t2)) pb + | _, Tconstr (p, lt) when EcEnv.Ty.defined p env -> + Queue.push (`TyUni (t1, EcEnv.Ty.unfold p lt env)) pb + | _, _ -> failure () + end + end -(* -------------------------------------------------------------------- *) -let subst_of_uf (uf : UF.t) = - let close = close uf in - let uids = UF.domain uf in - List.fold_left - (fun m uid -> - match close (tuni uid) with - | { ty_node = Tunivar uid' } when uid_equal uid uid' -> m - | t -> Muid.add uid t m - ) - Muid.empty - uids + | `TcCtt (uid, ty, tc) when + TcUni.Muid.mem uid (!uc).tcenv.resolution -> + (* [uid] was already pinned (e.g. by a prior [TcTw] equation + from the surrounding goal). Honor that binding rather than + re-running strategies, which could produce a different + witness on ambiguous instance lookups. *) + ignore (ty, tc) + + | `TcCtt (uid, ty, tc) -> + (* Op name attached to this problem at creation time, used + below to filter parent-DAG paths whose cumulative renaming + would clobber that name. None for TC problems not tied to + a specific named op. *) + let pb_op : EcSymbols.symbol option = + match TcUni.Muid.find_opt uid (!uc).tcenv.problems with + | Some (_, _, op) -> op + | None -> None in + (* See doc/typeclasses-inference.md for the strategy framework + and the catalog of inference modes this resolver covers. *) + let deps = ref TyUni.Suid.empty in + + let rec check_ty (ty : ty) : ty = + match ty.ty_node with + | Tunivar tyuvar -> begin + match UF.data tyuvar (!uc).uf with + | None -> + deps := TyUni.Suid.add tyuvar !deps; + ty + | Some ty -> + check_ty ty + end + | _ -> ty_map check_ty ty in + + let rec check_tcw (tcw : tcwitness) : tcwitness = + match tcw with + | TCIUni (tcuid, lift) -> begin + match TcUni.Muid.find_opt tcuid (!uc).tcenv.resolution with + | Some (TCIUni (tcuid', _)) when TcUni.uid_equal tcuid tcuid' -> tcw + | Some tcw' -> bump_lift lift (check_tcw tcw') + | None -> tcw + end + | TCIConcrete cw -> + let etyargs = List.map check_etyarg cw.etyargs in + TCIConcrete { cw with etyargs } + | TCIAbstract _ -> tcw + and check_etyarg ((ty, tcws) : etyarg) = + (check_ty ty, List.map check_tcw tcws) in + + let tc = + { tc with tc_args = List.map check_etyarg tc.tc_args } in + + let ty = check_ty ty in + let deps = !deps in + + (* ---- Helpers shared across strategies ---- *) + (* [tvtc] stores TC constraints as they were typed at tparam + declaration; the args may still mention Tunivars that were + since merged in [uf]. Dereference via [check_etyarg] before + structural comparison. *) + let deref_tc (tc' : typeclass) = + { tc' with tc_args = List.map check_etyarg tc'.tc_args } in + (* Compare on type arguments only; the corresponding tcwitnesses + are determined by [(carrier, type args)] and may legitimately + differ in form (e.g. unresolved TCIUni vs concrete) while + still picking out the same TC. *) + let eq_tc (tc' : typeclass) = + let tc' = deref_tc tc' in + EcPath.p_equal tc.tc_name tc'.tc_name + && List.length tc.tc_args = List.length tc'.tc_args + && List.for_all2 + (fun (a, _) (b, _) -> EcCoreEqTest.for_type env a b) + tc.tc_args tc'.tc_args in + + (* Enumerate all parent-DAG paths from [tc'] to [tc]. Each + returned entry is a list of parent-edge indices paired + with the cumulative ancestor→child op renaming along the + walk. [[]] means [tc' = tc] directly. With single-parent + inheritance the path is always all-zeros; with + multi-parent (factory) classes the path encodes which + parent edge is taken at each step. + + The renaming is needed downstream to filter paths by + op-name preservation: when querying op [n] via this TC, + only paths whose cumulative renaming preserves [n] can + expose it under the same name at the carrier site. *) + let with_lift tc' + : (int list * (EcSymbols.symbol * EcSymbols.symbol) list) list + = + let rec walk tc ren path acc = + if eq_tc tc then (List.rev path, ren) :: acc + else + let decl = EcEnv.TypeClass.by_path tc.tc_name env in + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> Mid.add a etyarg s) + Mid.empty decl.tc_tparams tc.tc_args in + List.fold_lefti + (fun acc i (parent, p_ren) -> + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + let ren' = + EcTypeClass.compose_renaming ~outer:p_ren ~inner:ren in + walk parent ren' (i :: path) acc) + acc decl.tc_prts + in walk tc' [] [] [] in + (* Returns all valid [(offset, path, renaming)] matches + across [tcs], one per (offset, parent-path) pair that + reaches [tc]. The renaming is the cumulative + ancestor→child op renaming for that path. *) + let match_tc_offsets_all (tcs : typeclass list) + : (int * int list * (EcSymbols.symbol * EcSymbols.symbol) list) list + = + List.concat (List.mapi + (fun i tc' -> + List.map (fun (p, ren) -> (i, p, ren)) (with_lift tc')) + tcs) in + (* Op-name-aware variant: when [pb_op] is set, drop paths + whose cumulative renaming clobbers the op name. *) + let match_tc_offsets (tcs : typeclass list) = + let cands = match_tc_offsets_all tcs in + match pb_op with + | None -> cands + | Some n -> + List.filter + (fun (_, _, ren) -> EcTypeClass.op_preserved ren n) + cands in + let rens_equal r1 r2 = + List.length r1 = List.length r2 + && List.for_all (fun (a, b) -> + match List.assoc_opt a r2 with + | Some b' -> b = b' + | None -> false) r1 in + let match_tc_offset (tcs : typeclass list) + : (int * int list * (EcSymbols.symbol * EcSymbols.symbol) list) option + = + (* Multi-parent inheritance can yield several parent-DAG + paths reaching the same target TC. When all such paths + carry the same cumulative renaming, they're + semantically interchangeable, so picking the canonical + (BFS-first) encoding is safe. Only ambiguity-preserving + (different renamings) genuinely blocks resolution. *) + match match_tc_offsets tcs with + | [] -> None + | m :: rest -> + let (_, _, ren_m) = m in + if not (List.for_all (fun (_, _, r) -> rens_equal r ren_m) rest) + then None + else + match EcTcCanonical.canonical_path env tcs tc.tc_name ren_m with + | Some (off, lift) -> Some (off, lift, ren_m) + | None -> Some m in + + (* ---- Strategies (catalog modes) ---- + Each strategy returns [Some witness] when it resolves, or + [None] when it does not apply / cannot decide. The dispatcher + below tries them in priority order. *) + + (* Mode #5: carrier is [Tvar a] with a in [tvtc]. *) + let strat_tvar_via_tvtc () : tcwitness option = + match ty.ty_node with + | Tvar a -> + let tcs = ofdfl failure (Mid.find_opt a (!uc).tvtc) in + Option.map + (fun (offset, lift, _ren) -> + TCIAbstract { support = `Var a; offset; lift }) + (match_tc_offset tcs) + | _ -> None in + + (* Mode #6: carrier is [Tconstr p] with [p] an abstract decl. *) + let strat_abs_via_decl () : tcwitness option = + match ty.ty_node with + | Tconstr (p, _) -> begin + match EcEnv.Ty.by_path_opt p env with + | Some { tyd_type = `Abstract tcs; _ } -> + Option.map + (fun (offset, lift, _ren) -> + TCIAbstract { support = `Abs p; offset; lift }) + (match_tc_offset tcs) + | _ -> None + end + | _ -> None in + + (* Modes #1, #2: carrier is ground; query the instance database. *) + let strat_infer_by_carrier () : tcwitness option = + EcTypeClass.infer env ty tc in + (* Ambiguity check: when multiple resolutions match, defer + so that later [TcTw] equations from the surrounding goal + can pin the univar to the correct one. + + Two sources of ambiguity: + - Concrete carriers: [infer_all] returns multiple + synthesised instances (multi-flavor inheritance). + - Tvar / abstract-type carriers: [match_tc_offsets] + returns multiple (offset, path) pairs (multiple parent + paths through the DAG to the same target TC). *) + (* Multiple paths with identical renamings are not + genuinely ambiguous — [match_tc_offset] picks one. *) + let paths_genuinely_ambiguous tcs = + match match_tc_offsets tcs with + | [] | [_] -> false + | m :: rest -> + let (_, _, ren_m) = m in + not (List.for_all (fun (_, _, r) -> rens_equal r ren_m) rest) in + let strat_carrier_is_ambiguous () : bool = + match ty.ty_node with + | Tvar a -> begin + match Mid.find_opt a (!uc).tvtc with + | None -> false + | Some tcs -> paths_genuinely_ambiguous tcs + end + | Tconstr (p, _) -> begin + let by_decl = + match EcEnv.Ty.by_path_opt p env with + | Some { tyd_type = `Abstract tcs; _ } -> + paths_genuinely_ambiguous tcs + | _ -> false in + by_decl + || List.length (EcTypeClass.infer_all env ty tc) > 1 + end + | _ -> false in + + (* Univars appearing in [tc.tc_args] (types and witnesses). + Used both for the Mode-#3 strategy gating and to register + extra parking edges so the problem re-fires when one of + them is resolved later. *) + let etyarg_univars (a, ws) = + let from_ty = Tuni.univars a in + List.fold_left (fun s w -> + TyUni.Suid.union s + (tcw_fold + (fun s t -> TyUni.Suid.union s (Tuni.univars t)) + TyUni.Suid.empty w)) + from_ty ws in + let arg_deps = + List.fold_left (fun s a -> TyUni.Suid.union s (etyarg_univars a)) + TyUni.Suid.empty tc.tc_args in + + (* Mode #3: carrier is a univar; identify a unique matching + instance by [tc.tc_args] (Tunivars on the goal side act + as wildcards), then push a [`TyUni (ty, tci_type)] + equation. The carrier resolution will then re-fire this + TcCtt under Mode #1 and produce the concrete witness. + For TCs with no args (e.g. [addmonoid]), all instances + match trivially, so by-args contributes no signal — skip. *) + let strat_infer_by_args () : tcwitness option = + if List.is_empty tc.tc_args then None else + let cands = EcTypeClass.candidates_by_args env tc in + (* Multiple matches: check whether they agree on the + carrier ([tci_type]). If they do, any of them works; if + they don't, the goal is genuinely ambiguous and no + further unification can decide between them. *) + if List.length cands >= 2 then begin + let carriers = + List.map (fun (_, tci, _) -> tci.tci_type) cands in + let same = + match carriers with + | [] | [_] -> true + | t :: rest -> + List.for_all (fun t' -> + EcCoreEqTest.for_type env t t') rest in + if not same then begin + let paths = List.filter_map (fun (p, _, _) -> p) cands in + raise (AmbiguousTcInstance (tc, paths)) + end + end; + match cands with + | [(Some _, tci, _map)] -> begin + (* Recover the candidate's [tgp.tc_args] (the patterns). *) + let tgargs = + match tci.tci_instance with + | `General (tgp, _) -> tgp.tc_args + | _ -> assert false in + (* Open the candidate's tparams as fresh univars. *) + let inst_subst = + List.fold_left (fun subst (a, _) -> + let (uc', (fresh_ty, _)) = fresh (!uc) in + uc := uc' ; + Mid.add a (fresh_ty, []) subst + ) Mid.empty tci.tci_params in + let tgargs = + List.map (EcCoreSubst.Tvar.subst_etyarg inst_subst) tgargs in + let inst_carrier = + EcCoreSubst.Tvar.subst inst_subst tci.tci_type in + (* Push TyUni equations: each goal arg unifies with the + candidate's substituted arg, and the carrier with + [tci_type]. The unifier binds goal Tunivars to the + corresponding patterns and triggers Mode #1 re-firing + once the carrier is concrete. *) + List.iter2 (fun (gty, _) (pty, _) -> + Queue.push (`TyUni (gty, pty)) pb) + tc.tc_args tgargs; + Queue.push (`TyUni (ty, inst_carrier)) pb; + None (* Defer witness construction; Mode #1 will fire. *) + end + | _ -> None in + + (* ---- Dispatch ---- *) + if TyUni.Suid.is_empty deps then begin + let ambiguous = + match ty.ty_node with + | Tvar _ | Tconstr _ -> strat_carrier_is_ambiguous () + | _ -> false in + let resolution_opt = + if ambiguous then None + else + match ty.ty_node with + | Tvar _ -> + strat_tvar_via_tvtc () + | Tconstr _ when Option.is_some (strat_abs_via_decl ()) -> + strat_abs_via_decl () + | _ -> + strat_infer_by_carrier () + in + match resolution_opt with + | Some resolution -> + uc := { !uc with tcenv = { (!uc).tcenv with resolution = + TcUni.Muid.add uid resolution (!uc).tcenv.resolution + } } + | None when not (TyUni.Suid.is_empty arg_deps) -> + (* Carrier is concrete but TC arg univars still pending; + park on those so we re-fire when they bind. *) + TyUni.Suid.iter (fun tyvar -> + uc := { !uc with tcenv = { (!uc).tcenv with byunivar = + TyUni.Muid.change (fun map -> + let map = Option.value ~default:TcUni.Suid.empty map in + Some (TcUni.Suid.add uid map) + ) tyvar (!uc).tcenv.byunivar + } } + ) arg_deps + | None when ambiguous -> + (* Defer: hold this TcCtt unresolved. A later [TcTw] + equation from the surrounding goal will pin [uid] + to the goal's specific witness via [bind_uni]. *) + () + | None -> failure () + end else begin + match strat_infer_by_args () with + | Some witness -> + uc := { !uc with tcenv = { (!uc).tcenv with resolution = + TcUni.Muid.add uid witness (!uc).tcenv.resolution + } } + | None -> + (* Mode #4: carrier still has univars; park on each. + Also park on [arg_deps] so a later binding of a + typeclass argument re-fires Mode #3. *) + TyUni.Suid.iter (fun tyvar -> + uc := { !uc with tcenv = { (!uc).tcenv with byunivar = + TyUni.Muid.change (fun map -> + let map = Option.value ~default:TcUni.Suid.empty map in + Some (TcUni.Suid.add uid map) + ) tyvar (!uc).tcenv.byunivar + } } + ) (TyUni.Suid.union deps arg_deps) + end + + | `TcTw (w1, w2) -> + (* Resolve a [TCIUni (u, l)] one level: if [u] has a known + resolution [r], return [bump_lift l r]; otherwise leave the + reference intact. This is local to the current unification + attempt. *) + let resolve_uni = function + | TCIUni (uid, lift) -> begin + match TcUni.Muid.find_opt uid (!uc).tcenv.resolution with + | Some w -> bump_lift lift w + | None -> TCIUni (uid, lift) + end + | w -> w in + + let w1 = resolve_uni w1 in + let w2 = resolve_uni w2 in + + let bind_uni uid lift target = + (* We want [bump_lift lift R = target] where [R] is the + resolution of [uid] (a witness for [uid]'s carrier-type + at [uid]'s TC class). With canonical-encoded paths + everywhere (Stage 2 Phase A/C), [target]'s path ends + with [lift] when reachable via [uid], so structural + suffix-strip recovers [R]. *) + let strip_suffix sfx l = + match sfx, List.rev l with + | [], _ -> Some l + | _, [] -> None + | _, _ -> + let sfx_rev = List.rev sfx in + let l_rev = List.rev l in + let rec eq_pref a b = + match a, b with + | [], _ -> Some (List.rev b) + | _, [] -> None + | x :: xs, y :: ys when x = y -> eq_pref xs ys + | _ -> None + in eq_pref sfx_rev l_rev in + let strip_lift sfx w = + match w with + | TCIUni (u, l) -> + Option.map (fun l' -> TCIUni (u, l')) (strip_suffix sfx l) + | TCIConcrete c -> + Option.map (fun l' -> TCIConcrete { c with lift = l' }) + (strip_suffix sfx c.lift) + | TCIAbstract a -> + Option.map (fun l' -> TCIAbstract { a with lift = l' }) + (strip_suffix sfx a.lift) in + match strip_lift lift target with + | None -> failure () + | Some r -> + uc := { !uc with tcenv = { (!uc).tcenv with resolution = + TcUni.Muid.add uid r (!uc).tcenv.resolution + } } in + + begin match w1, w2 with + | TCIUni (u1, l1), TCIUni (u2, l2) when TcUni.uid_equal u1 u2 -> + if l1 <> l2 then failure () + + | TCIUni (uid, lift), w + | w, TCIUni (uid, lift) -> + bind_uni uid lift w + + | _, _ -> + let w1 = EcTcCanonical.canonicalise_witness env w1 in + let w2 = EcTcCanonical.canonicalise_witness env w2 in + if not (EcAst.tcw_equal w1 w2) then failure () + end + done + in + doit (); !uc +end (* -------------------------------------------------------------------- *) type unienv_r = { - ue_uf : UF.t; + ue_uc : Unify.ucore; ue_named : EcIdent.t Mstr.t; ue_decl : EcIdent.t list; ue_closed : bool; @@ -208,12 +817,19 @@ type unienv_r = { type unienv = unienv_r ref +type petyarg = ty option * tcwitness option list option + type tvar_inst = -| TVIunamed of ty list -| TVInamed of (EcSymbols.symbol * ty) list +| TVIunamed of petyarg list +| TVInamed of (EcSymbols.symbol * petyarg) list type tvi = tvar_inst option -type uidmap = uid -> ty option + +let tvi_unamed (ety : etyarg list) : tvar_inst = + TVIunamed (List.map + (fun (ty, tcw) -> Some ty, Some (List.map Option.some tcw)) + ety + ) module UniEnv = struct let copy (ue : unienv) : unienv = @@ -234,119 +850,350 @@ module UniEnv = struct }; id end - let create (vd : EcIdent.t list option) = - let ue = { - ue_uf = UF.initial; - ue_named = Mstr.empty; - ue_decl = []; - ue_closed = false; - } in - + let create (vd : (EcIdent.t * typeclass list) list option) : unienv = let ue = match vd with - | None -> ue + | None -> + { ue_uc = Unify.initial_ucore () + ; ue_named = Mstr.empty + ; ue_decl = [] + ; ue_closed = false + } + | Some vd -> - let vdmap = List.map (fun x -> (EcIdent.name x, x)) vd in - { ue with - ue_named = Mstr.of_list vdmap; - ue_decl = List.rev vd; - ue_closed = true; } - in - ref ue + let vdmap = List.map (fun (x, _) -> (EcIdent.name x, x)) vd in + let tvtc = Mid.of_list vd in + { ue_uc = Unify.initial_ucore ~tvtc () + ; ue_named = Mstr.of_list vdmap + ; ue_decl = List.rev_map fst vd + ; ue_closed = true; + } + in ref ue + + let push ((x, tc) : ident * typeclass list) (ue : unienv) = + assert (not (Mstr.mem (EcIdent.name x) (!ue).ue_named)); + assert ((!ue).ue_closed); + + ue := + { ue_uc = { (!ue).ue_uc with tvtc = Mid.add x tc (!ue).ue_uc.tvtc } + ; ue_named = Mstr.add (EcIdent.name x) x (!ue).ue_named + ; ue_decl = x :: (!ue).ue_decl + ; ue_closed = true } + + let xfresh + ?(op_name : EcSymbols.symbol option) + ?(tcs : (typeclass * tcwitness option) list option) + ?(ty : ty option) + (ue : unienv) + = + let (uc, tytw) = Unify.fresh ?op_name ?tcs ?ty (!ue).ue_uc in + ue := { !ue with ue_uc = uc }; tytw let fresh ?(ty : ty option) (ue : unienv) = - let (uf, uid) = UnifyCore.fresh ?ty (!ue).ue_uf in - ue := { !ue with ue_uf = uf }; uid - - let opentvi (ue : unienv) (params : ty_params) (tvi : tvar_inst option) = - match tvi with - | None -> - List.fold_left - (fun s v -> Mid.add v (fresh ue) s) - Mid.empty params - - | Some (TVIunamed lt) -> - List.fold_left2 - (fun s v ty -> Mid.add v (fresh ~ty ue) s) - Mid.empty params lt - - | Some (TVInamed lt) -> - let for1 s v = - let t = - try fresh ~ty:(List.assoc (EcIdent.name v) lt) ue - with Not_found -> fresh ue - in - Mid.add v t s - in - List.fold_left for1 Mid.empty params + let (uc, (ty, tw)) = Unify.fresh ?ty (!ue).ue_uc in + assert (List.is_empty tw); + ue := { !ue with ue_uc = uc }; ty + + type opened = { + subst : etyarg Mid.t; + params : (ty * typeclass list) list; + args : etyarg list; + } + + let subst_tv (subst : etyarg Mid.t) (params : ty_params) = + List.map (fun (tv, tcs) -> + let tv = Tvar.subst subst (tvar tv) in + let tcs = + List.map + (fun tc -> + let tc_args = + List.map (Tvar.subst_etyarg subst) tc.tc_args + in { tc with tc_args }) + tcs + in (tv, tcs)) params + + let opentvi + ?(op_name : EcSymbols.symbol option) + (ue : unienv) (params : ty_params) (tvi : tvi) : opened = + let tvi = + match tvi with + | None -> + List.map (fun (v, tcs) -> + (v, (None, List.map (fun x -> (x, None)) tcs)) + ) params + + | Some (TVIunamed lt) -> + let combine (v, tc) (ty, tcw) = + let tctcw = + match tcw with + | None -> + List.map (fun tc -> (tc, None)) tc + | Some tcw -> + List.combine tc tcw + in (v, (ty, tctcw)) in + + List.map2 combine params lt + + | Some (TVInamed lt) -> + List.map (fun (v, tc) -> + let ty, tcw = + List.assoc_opt (EcIdent.name v) lt + |> Option.value ~default:(None, None) in + + let tcw = + match tcw with + | None -> + List.map (fun _ -> None) tc + | Some tcw -> + tcw in + + (v, (ty, List.map2 (fun x y -> (x, y)) tc tcw)) + ) params + in + + let subst = + List.fold_left (fun s (v, (ty, tcws)) -> + let tcs = + let for1 (tc, tcw) = + let tc = + { tc_name = tc.tc_name; + tc_args = List.map (Tvar.subst_etyarg s) tc.tc_args } in + (tc, tcw) + in List.map for1 tcws + in Mid.add v (xfresh ?op_name ?ty ~tcs ue) s + ) Mid.empty tvi in - let subst_tv (subst : ty -> ty) (params : ty_params) = - List.map (fun tv -> subst (tvar tv)) params + let args = List.map (fun (x, _) -> oget (Mid.find_opt x subst)) params in + let params = subst_tv subst params in - let openty_r (ue : unienv) (params : ty_params) (tvi : tvar_inst option) = - let subst = f_subst_init ~tv:(opentvi ue params tvi) () in - (subst, subst_tv (ty_subst subst) params) + { subst; args; params; } - let opentys (ue : unienv) (params : ty_params) (tvi : tvar_inst option) (tys : ty list) = - let (subst, tvs) = openty_r ue params tvi in - (List.map (ty_subst subst) tys, tvs) + let opentys (ue : unienv) (params : ty_params) (tvi : tvi) (tys : ty list) = + let opened = opentvi ue params tvi in + let tys = List.map (Tvar.subst opened.subst) tys in + tys, opened - let openty (ue : unienv) (params : ty_params) (tvi : tvar_inst option) (ty : ty)= - let (subst, tvs) = openty_r ue params tvi in - (ty_subst subst ty, tvs) + let openty (ue : unienv) (params : ty_params) (tvi : tvi) (ty : ty) = + let opened = opentvi ue params tvi in + Tvar.subst opened.subst ty, opened let repr (ue : unienv) (t : ty) : ty = match t.ty_node with - | Tunivar id -> odfl t (UF.data id (!ue).ue_uf) + | Tunivar id -> odfl t (Unify.UF.data id (!ue).ue_uc.uf) | _ -> t + let xclosed (ue : unienv) = + try Unify.check_closed (!ue).ue_uc; None + with UninstanciateUni infos -> Some infos + let closed (ue : unienv) = - UF.closed (!ue).ue_uf + Option.is_none (xclosed ue) - let close (ue : unienv) = - if not (closed ue) then raise UninstantiateUni; - (subst_of_uf (!ue).ue_uf) + let assubst (ue : unienv) : ty TyUni.Muid.t = + Unify.subst_of_uf (!ue).ue_uc - let assubst (ue : unienv) = - subst_of_uf (!ue).ue_uf + let tw_assubst (ue : unienv) : tcwitness TcUni.Muid.t = + (!ue).ue_uc.tcenv.resolution - let tparams (ue : unienv) : ty_params = - List.rev (!ue).ue_decl + let close (ue : unienv) = + Unify.check_closed (!ue).ue_uc; + assubst ue + + (* Drain the pending TcCtt queue: invokes [Unify.unify_core] on a + trivially-true [TyUni] problem, which causes the unifier to first + re-process every parked [TcCtt] in [tcenv.problems]. After this, + any constraint that the strategies (Mode #1 .. #6) can resolve is + committed to [tcenv.resolution]. Constraints that defer (ambiguous + or carrier-with-univars) stay parked. *) + let flush_tc_problems (env : EcEnv.env) (ue : unienv) : unit = + if not (TcUni.Muid.is_empty (!ue).ue_uc.tcenv.problems) then + try + let trig = tunit in + let uc = Unify.unify_core env (!ue).ue_uc (`TyUni (trig, trig)) in + ue := { !ue with ue_uc = uc } + with UnificationFailure _ -> () + + let tparams (ue : unienv) = + let close = Unify.close (!ue).ue_uc in + let deref_tc (tc : typeclass) : typeclass = + let tc_args = + List.map + (fun (t, ws) -> (close.tyuni t, List.map close.tcuni ws)) + tc.tc_args + in { tc with tc_args } + in + let fortv x = + let tvtc = odfl [] (Mid.find_opt x (!ue).ue_uc.tvtc) in + List.map deref_tc tvtc in + List.map (fun x -> (x, fortv x)) (List.rev (!ue).ue_decl) end +(* -------------------------------------------------------------------- *) +let unify_core (env : EcEnv.env) (ue : unienv) (pb : problem) = + let uc = Unify.unify_core env (!ue).ue_uc pb in + ue := { !ue with ue_uc = uc; } + (* -------------------------------------------------------------------- *) let unify (env : EcEnv.env) (ue : unienv) (t1 : ty) (t2 : ty) = - let uf = unify_core env (!ue).ue_uf (`TyUni (t1, t2)) in - ue := { !ue with ue_uf = uf; } + unify_core env ue (`TyUni (t1, t2)) (* -------------------------------------------------------------------- *) -let tfun_expected ue ?retty psig = - let retty = ofdfl (fun () -> UniEnv.fresh ue) retty in - EcTypes.toarrow psig retty +let unify_tcw (env : EcEnv.env) (ue : unienv) (w1 : tcwitness) (w2 : tcwitness) = + unify_core env ue (`TcTw (w1, w2)) (* -------------------------------------------------------------------- *) -type sbody = ((EcIdent.t * ty) list * expr) Lazy.t +let unify_etyarg (env : EcEnv.env) (ue : unienv) (e1 : etyarg) (e2 : etyarg) = + let (t1, ws1) = e1 and (t2, ws2) = e2 in + unify env ue t1 t2; + if List.length ws1 <> List.length ws2 then + raise (UnificationFailure (`TyUni (t1, t2))); + List.iter2 (unify_tcw env ue) ws1 ws2 (* -------------------------------------------------------------------- *) -type select_result = (EcPath.path * ty list) * ty * unienv * sbody option +(* When typing an op application like [(+)<:comring>], the witness for + the op's [<: monoid] tparam may be ambiguous: the carrier [comring] + reaches [monoid] via two parent walks (via [addgroup] and via + [mulmonoid]). The TC inference framework is op-name-agnostic, so + it sees both paths as candidates. + + But the parent-edge renamings disambiguate: only paths whose + cumulative ancestor→child renaming preserves the queried op name + actually expose that op under the same name at the carrier site. + + This helper, called right after [opentvi] at op-typing sites, walks + each fresh witness univar and binds it to the unique [TCIAbstract] + for the op-name-preserving path, when one exists. If zero or + multiple paths preserve the name, the witness is left as a univar + and existing strategies handle it as before. *) +let disambiguate_op_witnesses + (env : EcEnv.env) + (ue : unienv) + (op_name : EcSymbols.symbol) + (params : (ty * typeclass list) list) + (args : etyarg list) + : unit += + let close = Unify.close (!ue).ue_uc in + + (* Path enumeration with renaming, top-level analogue of the + [with_lift] inside [unify_core]. *) + let with_lift_for (carrier_tcs : typeclass list) (target : typeclass) + : (int * int list * (EcSymbols.symbol * EcSymbols.symbol) list) list + = + let target = + let tc_args = + List.map + (fun (t, ws) -> (close.tyuni t, List.map close.tcuni ws)) + target.tc_args + in { target with tc_args } in + let eq_tc (tc' : typeclass) = + let tc' = + let tc_args = + List.map + (fun (t, ws) -> (close.tyuni t, List.map close.tcuni ws)) + tc'.tc_args + in { tc' with tc_args } in + EcPath.p_equal target.tc_name tc'.tc_name + && List.length target.tc_args = List.length tc'.tc_args + && List.for_all2 + (fun (a, _) (b, _) -> EcCoreEqTest.for_type env a b) + target.tc_args tc'.tc_args in + let rec walk tc ren path acc = + if eq_tc tc then (List.rev path, ren) :: acc + else + let decl = EcEnv.TypeClass.by_path tc.tc_name env in + let subst = + List.fold_left2 + (fun s (a, _) etyarg -> Mid.add a etyarg s) + Mid.empty decl.tc_tparams tc.tc_args in + List.fold_lefti + (fun acc i (parent, p_ren) -> + let parent = EcCoreSubst.Tvar.subst_tc subst parent in + let ren' = + EcTypeClass.compose_renaming ~outer:p_ren ~inner:ren in + walk parent ren' (i :: path) acc) + acc decl.tc_prts in + List.concat (List.mapi + (fun i tc' -> + List.map (fun (p, ren) -> (i, p, ren)) (walk tc' [] [] [])) + carrier_tcs) in + + let try_pin + ~(build : int -> int list -> tcwitness) + (carrier_tcs : typeclass list) + (target : typeclass) + (w : tcwitness) + : unit + = + match w with + | TCIUni _ -> begin + let candidates = with_lift_for carrier_tcs target in + if List.length candidates < 2 then () + else + let preserved = + List.filter + (fun (_, _, ren) -> EcTypeClass.op_preserved ren op_name) + candidates in + match preserved with + | [(offset, lift, _)] -> + (try unify_tcw env ue w (build offset lift) + with UnificationFailure _ -> ()) + | _ -> () + end + | _ -> () in + + List.iter2 (fun (carrier_ty, tcs) (_, ws) -> + let carrier_ty = close.tyuni carrier_ty in + if List.length tcs <> List.length ws then () else + match carrier_ty.ty_node with + | Tvar a -> begin + match Mid.find_opt a (!ue).ue_uc.tvtc with + | None -> () + | Some carrier_tcs -> + let build offset lift = + TCIAbstract { support = `Var a; offset; lift } in + List.iter2 (try_pin ~build carrier_tcs) tcs ws + end + | Tconstr (p, _) -> begin + match EcEnv.Ty.by_path_opt p env with + | Some { tyd_type = `Abstract carrier_tcs; _ } -> + let build offset lift = + TCIAbstract { support = `Abs p; offset; lift } in + List.iter2 (try_pin ~build carrier_tcs) tcs ws + | _ -> () + end + | _ -> () + ) params args + +(* -------------------------------------------------------------------- *) +let tfun_expected (ue : unienv) ?retty (psig : ty list) = + let ret = match retty with Some t -> t | None -> UniEnv.fresh ue in + toarrow psig ret + +(* -------------------------------------------------------------------- *) +type sbody = ((EcIdent.t * ty) list * expr) Lazy.t (* -------------------------------------------------------------------- *) +type select_filter_t = EcPath.path -> operator -> bool + +type select_t = + ((EcPath.path * etyarg list) * ty * unienv * sbody option) list + let select_op ?(hidden : bool = false) - ?(filter : EcPath.path -> operator -> bool = fun _ _ -> true) - (tvi : tvar_inst option) + ?(filter : select_filter_t = fun _ _ -> true) + ?(retty : ty option) + (tvi : tvi) (env : EcEnv.env) (name : qsymbol) (ue : unienv) - (sig_ : ty list * ty option) - : select_result list + (psig : dom) + : select_t = ignore hidden; (* FIXME *) let module D = EcDecl in - let (psig, retty) = sig_ in - let filter oppath op = (* Filter operator based on given type variables instanciation *) let filter_on_tvi = @@ -357,11 +1204,17 @@ let select_op let len = List.length lt in fun op -> let tparams = op.D.op_tparams in - List.length tparams = len + List.length tparams = len && + List.for_all2 + (fun (_, tcs) (_, tcw) -> + match tcw with + | None -> true + | Some tcw -> List.length tcs = List.length tcw) + tparams lt | Some (TVInamed ls) -> fun op -> - let tparams = List.map EcIdent.name op.D.op_tparams in - let tparams = Ssym.of_list tparams in + let tparams = List.map (fst_map EcIdent.name) op.D.op_tparams in + let tparams = Msym.of_list tparams in List.for_all (fun (x, _) -> Msym.mem x tparams) ls in @@ -374,26 +1227,54 @@ let select_op let subue = UniEnv.copy ue in try - let (tip, tvs) = UniEnv.openty_r subue op.D.op_tparams tvi in - let top = ty_subst tip op.D.op_ty in + let UniEnv.{ subst = tip_full; args; params = oparams } = + UniEnv.opentvi ~op_name:(EcPath.basename path) + subue op.D.op_tparams tvi in + let tip = f_subst_init ~tv:(Mid.map fst tip_full) () in + + let top = EcCoreSubst.ty_subst tip op.D.op_ty in let texpected = tfun_expected subue ?retty psig in (try unify env subue top texpected with UnificationFailure _ -> raise E.Failure); - let bd = + (* After type unification has pinned the carrier(s), try to + disambiguate any TC witnesses by op-name preservation along + parent walks. This is what lets [(+)<:comring>] pick the + addgroup walk uniquely when [comring] inherits from both + [addgroup] and [mulmonoid with (+) = ( * )]. *) + disambiguate_op_witnesses env subue + (EcPath.basename path) oparams args; + + let bd = match op.D.op_kind with | OB_nott nt -> let substnt () = - let xs = List.map (snd_map (ty_subst tip)) nt.D.ont_args in - let es = e_subst tip in + (* Substitute tparams (both type and TC-witness univars + bound during unification) into the abbrev body. We + pass [~tw] alongside [~tv] so [TCIAbstract \`Var] + witnesses captured at abbrev-definition time get + rewritten through the tparam => etyarg map; without + it the body keeps stale [\`Var] references to the + abbrev's tparams. [~tw_uni] resolves [TCIUni] + placeholders left over from [opentvi]; without it the + body prints with uninferrable [#a[#b]] witnesses. *) + let s = + f_subst_init + ~tv:(Mid.map fst tip_full) + ~tw:(Mid.map snd tip_full) + ~tw_uni:(UniEnv.tw_assubst subue) + () in + let xs = List.map (snd_map (ty_subst s)) nt.D.ont_args in + let es = e_subst s in let bd = es nt.D.ont_body in (xs, bd) in Some (Lazy.from_fun substnt) | _ -> None + in - in Some ((path, tvs), top, subue, bd) + Some ((path, args), top, subue, bd) with E.Failure -> None diff --git a/src/ecUnify.mli b/src/ecUnify.mli index d19596eb6b..50342e0df7 100644 --- a/src/ecUnify.mli +++ b/src/ecUnify.mli @@ -1,53 +1,87 @@ (* -------------------------------------------------------------------- *) -open EcUid +open EcIdent open EcSymbols -open EcPath open EcTypes +open EcAst open EcDecl -(* -------------------------------------------------------------------- *) -exception UnificationFailure of [`TyUni of ty * ty] -exception UninstantiateUni +(* ==================================================================== *) +type problem = [ + | `TyUni of ty * ty + | `TcTw of tcwitness * tcwitness + | `TcCtt of EcAst.tcuni * ty * typeclass +] + +type uniflags = { tyvars: bool; tcvars: bool; } + +exception UnificationFailure of problem +exception UninstanciateUni of uniflags + +(* Raised by the unifier's By-args strategy when a typeclass with + ground arguments has multiple matching instances and no further + unification can disambiguate. The first field is the offending + typeclass; the second is the list of candidate instance paths. *) +exception AmbiguousTcInstance of typeclass * EcPath.path list type unienv +type petyarg = ty option * tcwitness option list option + type tvar_inst = -| TVIunamed of ty list -| TVInamed of (EcSymbols.symbol * ty) list +| TVIunamed of petyarg list +| TVInamed of (EcSymbols.symbol * petyarg) list type tvi = tvar_inst option -type uidmap = uid -> ty option + +val tvi_unamed : etyarg list -> tvar_inst module UniEnv : sig - val create : ty_params option -> unienv + type opened = { + subst : etyarg Mid.t; + params : (ty * typeclass list) list; + args : etyarg list; + } + + val create : (EcIdent.t * typeclass list) list option -> unienv + val push : (EcIdent.t * typeclass list) -> unienv -> unit val copy : unienv -> unienv (* constant time *) val restore : dst:unienv -> src:unienv -> unit (* constant time *) + val xfresh : ?op_name:symbol -> ?tcs:(typeclass * EcTypes.tcwitness option) list -> ?ty:ty -> unienv -> etyarg val fresh : ?ty:ty -> unienv -> ty val getnamed : unienv -> symbol -> EcIdent.t val repr : unienv -> ty -> ty - val opentvi : unienv -> ty_params -> tvi -> ty EcIdent.Mid.t - val openty : unienv -> ty_params -> tvi -> ty -> ty * ty list - val opentys : unienv -> ty_params -> tvi -> ty list -> ty list * ty list + val opentvi : ?op_name:symbol -> unienv -> ty_params -> tvi -> opened + val openty : unienv -> ty_params -> tvi -> ty -> ty * opened + val opentys : unienv -> ty_params -> tvi -> ty list -> ty list * opened val closed : unienv -> bool - val close : unienv -> ty Muid.t - val assubst : unienv -> ty Muid.t + val xclosed : unienv -> uniflags option + val close : unienv -> ty TyUni.Muid.t + val assubst : unienv -> ty TyUni.Muid.t + val tw_assubst : unienv -> tcwitness TcUni.Muid.t val tparams : unienv -> ty_params + + (* Drain the pending TC-constraint queue, attempting to resolve every + [TcCtt] problem currently parked. Useful before TC-op reduction + attempts (e.g. in matcher's [try_delta]) where a [TCIUni] witness + needs to be committed before [tc_core_reduce] can fire. *) + val flush_tc_problems : EcEnv.env -> unienv -> unit end -val unify : EcEnv.env -> unienv -> ty -> ty -> unit +val unify : EcEnv.env -> unienv -> ty -> ty -> unit +val unify_tcw : EcEnv.env -> unienv -> tcwitness -> tcwitness -> unit +val unify_etyarg : EcEnv.env -> unienv -> etyarg -> etyarg -> unit val tfun_expected : unienv -> ?retty:ty -> EcTypes.ty list -> EcTypes.ty type sbody = ((EcIdent.t * ty) list * expr) Lazy.t -type select_result = (EcPath.path * ty list) * ty * unienv * sbody option - val select_op : ?hidden:bool - -> ?filter:(path -> operator -> bool) + -> ?filter:(EcPath.path -> operator -> bool) + -> ?retty:ty -> tvi -> EcEnv.env -> qsymbol -> unienv - -> dom * ty option - -> select_result list + -> dom + -> ((EcPath.path * etyarg list) * ty * unienv * sbody option) list diff --git a/src/ecUserMessages.ml b/src/ecUserMessages.ml index 65a874e64c..81ab00d4e8 100644 --- a/src/ecUserMessages.ml +++ b/src/ecUserMessages.ml @@ -1,8 +1,8 @@ (* -------------------------------------------------------------------- *) open EcSymbols -open EcUid open EcPath open EcUtils +open EcAst open EcTypes open EcCoreSubst open EcEnv @@ -21,6 +21,7 @@ let set_ppo (newppo : pp_options) = module TypingError : sig open EcTyping + val pp_uniflags : Format.formatter -> EcUnify.uniflags -> unit val pp_fxerror : env -> Format.formatter -> fxerror -> unit val pp_tyerror : env -> Format.formatter -> tyerror -> unit val pp_cnv_failure : env -> Format.formatter -> tymod_cnv_failure -> unit @@ -30,6 +31,16 @@ module TypingError : sig end = struct open EcTyping + let pp_uniflags (fmt : Format.formatter) ({ tyvars; tcvars; } : EcUnify.uniflags) = + let msg = + match tyvars, tcvars with + | false, false -> None + | true, false -> Some "type" + | false, true -> Some "type-class" + | true, true -> Some "type&type-class" in + + Option.iter (Format.fprintf fmt "%s") msg + let pp_mismatch_funsig env0 fmt error = let ppe0 = EcPrinting.PPEnv.ofenv env0 in @@ -275,8 +286,10 @@ end = struct | UniVarNotAllowed -> msg "type place holders not allowed" - | FreeTypeVariables -> - msg "this expression contains free type variables" + | FreeUniVariables infos -> + msg + "this expression contains free %a variables" + pp_uniflags infos | TypeVarNotAllowed -> msg "type variables not allowed" @@ -362,6 +375,14 @@ end = struct | TypeClassMismatch -> msg "Type-class unification failure" + | TypeClassAmbiguous (tc, paths) -> + msg "ambiguous typeclass instance for @[%a@]@\n" + (EcPrinting.pp_typeclass env) tc; + msg " candidates:@\n"; + List.iter (fun p -> + msg " %a@\n" (EcPrinting.pp_axname env) p) + paths + | TypeModMismatch(mp, mt, err) -> msg "the module %a does not have the module type %a:@\n" (EcPrinting.pp_topmod env) mp @@ -391,7 +412,7 @@ end = struct | MultipleOpMatch (name, tys, matches) -> begin let uvars = List.map Tuni.univars tys in - let uvars = List.fold_left Suid.union Suid.empty uvars in + let uvars = List.fold_left TyUni.Suid.union TyUni.Suid.empty uvars in begin match tys with | [] -> @@ -409,7 +430,7 @@ end = struct let pp_op fmt ((op, inst), subue) = let uidmap = EcUnify.UniEnv.assubst subue in - let inst = Tuni.subst_dom uidmap inst in + let inst = Tuni.subst_dom uidmap (List.fst inst) in begin match inst with | [] -> @@ -422,8 +443,8 @@ end = struct end; let myuvars = List.map Tuni.univars inst in - let myuvars = List.fold_left Suid.union uvars myuvars in - let myuvars = Suid.elements myuvars in + let myuvars = List.fold_left TyUni.Suid.union uvars myuvars in + let myuvars = TyUni.Suid.elements myuvars in let uidmap = EcUnify.UniEnv.assubst subue in let tysubst = ty_subst (Tuni.subst uidmap) in @@ -570,6 +591,14 @@ end = struct | LvMapOnNonAssign -> msg "map-style left-value cannot be used with assignments" + | TCArgsCountMismatch (_, typarams, tys) -> + msg "typeclass expects %d arguments, got %d" + (List.length typarams) (List.length tys) + + | CannotInferTC (ty, tc) -> + msg "cannot infer typeclass `%a' for type `%a'" + (EcPrinting.pp_typeclass env) tc pp_type ty + | NoDefaultMemRestr -> msg "no default sign for memory restriction. Use '+' or '-', or \ set the %s pragma to retrieve the old behaviour" @@ -710,8 +739,10 @@ end = struct let pp_tperror (env : env) fmt = function | TPE_Typing e -> TypingError.pp_tyerror env fmt e - | TPE_TyNotClosed -> - Format.fprintf fmt "this predicate type contains free type variables" + | TPE_TyNotClosed infos -> + Format.fprintf fmt + "this predicate type contains free %a variables" + TypingError.pp_uniflags infos | TPE_DuplicatedConstr x -> Format.fprintf fmt "duplicated constructor name: `%s'" x end @@ -730,8 +761,10 @@ end = struct match error with | NTE_Typing e -> TypingError.pp_tyerror env fmt e - | NTE_TyNotClosed -> - msg "this notation type contains free type variables" + | NTE_TyNotClosed infos -> + msg + "this notation type contains free %a variables" + TypingError.pp_uniflags infos | NTE_DupIdent -> msg "an ident is bound several time" | NTE_UnknownBinder x -> @@ -1055,6 +1088,14 @@ let pp fmt exn = | EcLowGoal.Apply.NoInstance e -> pp_apply_error fmt e + | EcUnify.AmbiguousTcInstance (tc, paths) -> + Format.fprintf fmt "ambiguous typeclass instance for "; + Format.fprintf fmt "@[%s@]@\n" (EcPath.tostring tc.tc_name); + Format.fprintf fmt " candidates:@\n"; + List.iter (fun p -> + Format.fprintf fmt " %s@\n" (EcPath.tostring p)) + paths + | _ -> raise exn (* -------------------------------------------------------------------- *) diff --git a/src/ecUserMessages.mli b/src/ecUserMessages.mli index efe97e0efc..97d3e0d10b 100644 --- a/src/ecUserMessages.mli +++ b/src/ecUserMessages.mli @@ -14,6 +14,7 @@ val set_ppo : pp_options -> unit module TypingError : sig open EcTyping + val pp_uniflags : Format.formatter -> EcUnify.uniflags -> unit val pp_tyerror : env -> Format.formatter -> tyerror -> unit val pp_cnv_failure : env -> Format.formatter -> tymod_cnv_failure -> unit val pp_mismatch_funsig : env -> Format.formatter -> mismatch_funsig -> unit diff --git a/src/ecUtils.ml b/src/ecUtils.ml index 0cd31828dd..e3b42f8c42 100644 --- a/src/ecUtils.ml +++ b/src/ecUtils.ml @@ -116,6 +116,12 @@ type 'a tuple8 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a tuple9 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a pair = 'a * 'a +(* -------------------------------------------------------------------- *) +module SmartPair = struct + let mk ((a, b) as p) a' b' = + if a == a' && b == b' then p else (a', b') +end + (* -------------------------------------------------------------------- *) let t2_map (f : 'a -> 'b) (x, y) = (f x, f y) @@ -487,6 +493,17 @@ module List = struct | None -> failwith "List.last" | Some x -> x + let betail = + let rec aux (acc : 'a list) (s : 'a list) = + match s, acc with + | [], [] -> + failwith "List.betail" + | [], v :: vs-> + List.rev vs, v + | x :: xs, _ -> + aux (x :: acc) xs + in fun s -> aux [] s + let mbfilter (p : 'a -> bool) (s : 'a list) = match s with [] | [_] -> s | _ -> List.filter p s diff --git a/src/ecUtils.mli b/src/ecUtils.mli index 3141d814b6..e42949d99e 100644 --- a/src/ecUtils.mli +++ b/src/ecUtils.mli @@ -64,6 +64,11 @@ type 'a tuple8 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a tuple9 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a pair = 'a tuple2 +(* -------------------------------------------------------------------- *) +module SmartPair : sig + val mk : 'a * 'b -> 'a -> 'b -> 'a * 'b +end + (* -------------------------------------------------------------------- *) val in_seq1: ' a -> 'a list @@ -285,6 +290,7 @@ module List : sig val min : ?cmp:('a -> 'a -> int) -> 'a list -> 'a val max : ?cmp:('a -> 'a -> int) -> 'a list -> 'a + val betail : 'a list -> 'a list * 'a val destruct : 'a list -> 'a * 'a list val nth_opt : 'a list -> int -> 'a option val mbfilter : ('a -> bool) -> 'a list -> 'a list diff --git a/src/phl/ecPhlCond.ml b/src/phl/ecPhlCond.ml index baf9f449e6..aef86b5677 100644 --- a/src/phl/ecPhlCond.ml +++ b/src/phl/ecPhlCond.ml @@ -273,8 +273,8 @@ let t_equiv_match_same_constr tc = let bhl = List.map (fst_map EcIdent.fresh) cl in let bhr = List.map (fst_map EcIdent.fresh) cr in let cop = EcPath.pqoname (EcPath.prefix pl) c in - let copl = f_op cop tyl (toarrow (List.snd cl) fl.inv.f_ty) in - let copr = f_op cop tyr (toarrow (List.snd cr) fr.inv.f_ty) in + let copl = f_op_tc cop tyl (toarrow (List.snd cl) fl.inv.f_ty) in + let copr = f_op_tc cop tyr (toarrow (List.snd cr) fr.inv.f_ty) in let lhs = map_ts_inv1 (fun fl -> f_eq fl (f_app copl (List.map (curry f_local) bhl) fl.f_ty)) fl in let lhs = map_ts_inv1 (f_exists (List.map (snd_map gtty) bhl)) lhs in @@ -290,8 +290,8 @@ let t_equiv_match_same_constr tc = let sb, bhl = add_elocals sb cl in let sb, bhr = add_elocals sb cr in let cop = EcPath.pqoname (EcPath.prefix pl) c in - let copl = f_op cop tyl (toarrow (List.snd cl) fl.inv.f_ty) in - let copr = f_op cop tyr (toarrow (List.snd cr) fr.inv.f_ty) in + let copl = f_op_tc cop tyl (toarrow (List.snd cl) fl.inv.f_ty) in + let copr = f_op_tc cop tyr (toarrow (List.snd cr) fr.inv.f_ty) in let f_ands_simpl' f = f_ands_simpl (List.tl f) (List.hd f) in let pre = map_ts_inv f_ands_simpl' [es_pr es; map_ts_inv1 (fun fl -> f_eq fl (f_app copl (List.map (curry f_local) bhl) fl.f_ty)) fl; @@ -354,8 +354,8 @@ let t_equiv_match_eq tc = sb cl cr in let cop = EcPath.pqoname (EcPath.prefix pl) c in - let copl = f_op cop tyl (toarrow (List.snd cl) fl.inv.f_ty) in - let copr = f_op cop tyr (toarrow (List.snd cr) fr.inv.f_ty) in + let copl = f_op_tc cop tyl (toarrow (List.snd cl) fl.inv.f_ty) in + let copr = f_op_tc cop tyr (toarrow (List.snd cr) fr.inv.f_ty) in let f_ands_simpl' f = f_ands_simpl (List.tl f) (List.hd f) in let pre = map_ts_inv f_ands_simpl' [ es_pr es; map_ts_inv1 (fun fl -> f_eq fl (f_app copl (List.map (curry f_local) bh) fl.f_ty)) fl; diff --git a/src/phl/ecPhlEqobs.ml b/src/phl/ecPhlEqobs.ml index d7f9b0d1cd..b92d91c13e 100644 --- a/src/phl/ecPhlEqobs.ml +++ b/src/phl/ecPhlEqobs.ml @@ -266,7 +266,7 @@ and i_eqobs_in il ir sim local (eqo:Mpv2.t) = let typr, _, tyinstr = oget (EcEnv.Ty.get_top_decl el.e_ty env) in let test = EcPath.p_equal typl typr && - List.for_all2 (EcReduction.EqTest.for_type env) tyinstl tyinstr in + List.for_all2 (EcReduction.EqTest.for_etyarg env) tyinstl tyinstr in if not test then raise EqObsInError; let rsim = ref sim in let doit eqs1 (argsl,sl) (argsr, sr) = diff --git a/src/phl/ecPhlHiCond.ml b/src/phl/ecPhlHiCond.ml index cb19956841..602eb4391d 100644 --- a/src/phl/ecPhlHiCond.ml +++ b/src/phl/ecPhlHiCond.ml @@ -1,10 +1,11 @@ (* -------------------------------------------------------------------- *) open EcUtils +open EcAst +open EcMatching.Position open EcCoreGoal open EcLowGoal open EcLowPhlGoal open EcPhlCond -open EcMatching.Position (* -------------------------------------------------------------------- *) let process_cond (info : EcParsetree.pcond_info) tc = diff --git a/src/phl/ecPhlPrRw.ml b/src/phl/ecPhlPrRw.ml index 8709cce9e9..899fb9dd09 100644 --- a/src/phl/ecPhlPrRw.ml +++ b/src/phl/ecPhlPrRw.ml @@ -112,7 +112,7 @@ let p_BRA_big = EcPath.fromqsymbol (p_BRA, "big") let destr_pr_has pr = let m = pr.pr_event.m in match pr.pr_event.inv.f_node with - | Fapp ({ f_node = Fop(op, [ty_elem]) }, [f_f; f_l]) -> + | Fapp ({ f_node = Fop(op, [(ty_elem, _)]) }, [f_f; f_l]) -> if EcPath.p_equal p_list_has op && not (Mid.mem m f_l.f_fv) then Some(ty_elem, {m;inv=f_f}, f_l) else None diff --git a/src/phl/ecPhlRCond.ml b/src/phl/ecPhlRCond.ml index 81a78744ac..c169fc61bf 100644 --- a/src/phl/ecPhlRCond.ml +++ b/src/phl/ecPhlRCond.ml @@ -172,7 +172,7 @@ module LowMatch = struct in (x, xty)) cvars in let vars = List.map (curry f_local) names in let cty = toarrow (List.snd names) f.inv.f_ty in - let po = f_op cname (List.snd tyinst) cty in + let po = f_op_tc cname (List.snd tyinst) cty in let po = f_app po vars f.inv.f_ty in map_ss_inv1 (f_exists (List.map (snd_map gtty) names)) (map_ss_inv2 f_eq f {m;inv=po}) in @@ -201,7 +201,7 @@ module LowMatch = struct let epr, asgn = if frame then begin let vars = List.map (fun (pv, ty) -> f_pvar pv ty (fst me)) pvs in - let epr = f_op cname (List.snd tyinst) f.inv.f_ty in + let epr = f_op_tc cname (List.snd tyinst) f.inv.f_ty in let epr = map_ss_inv ~m:f.m (fun vars -> f_app epr vars f.inv.f_ty) vars in Some (map_ss_inv2 f_eq f epr), [] end else begin @@ -210,7 +210,7 @@ module LowMatch = struct (* FIXME: factorize out *) let rty = ttuple (List.snd cvars) in let proj = EcInductive.datatype_proj_path typ (EcPath.basename cname) in - let proj = e_op proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in + let proj = e_op_tc proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in let proj = e_app proj [e] (toption rty) in let proj = e_oget proj rty in i_asgn (lv, proj)) in diff --git a/src/phl/ecPhlRwEquiv.ml b/src/phl/ecPhlRwEquiv.ml index bc8c4ea7fe..ae92f68293 100644 --- a/src/phl/ecPhlRwEquiv.ml +++ b/src/phl/ecPhlRwEquiv.ml @@ -147,8 +147,8 @@ let process_rewrite_equiv info tc = let res = omap (fun v -> EcTyping.transexpcast subenv `InProc ue ret_ty v) pres in let es = e_subst (Tuni.subst (EcUnify.UniEnv.close ue)) in Some (List.map es args, omap (EcModules.lv_of_expr -| es) res) - with EcUnify.UninstantiateUni -> - EcTyping.tyerror (loc pargs) env EcTyping.FreeTypeVariables + with EcUnify.UninstanciateUni flags -> + EcTyping.tyerror (loc pargs) env (EcTyping.FreeUniVariables flags) end in diff --git a/src/phl/ecPhlSp.mli b/src/phl/ecPhlSp.mli index 3592d8caf5..1e2a383ac7 100644 --- a/src/phl/ecPhlSp.mli +++ b/src/phl/ecPhlSp.mli @@ -1,8 +1,8 @@ (* -------------------------------------------------------------------- *) +open EcUtils open EcParsetree open EcMatching.Position open EcCoreGoal.FApi -open EcUtils (* -------------------------------------------------------------------- *) val t_sp : (codegap1 doption) option -> backward diff --git a/src/phl/ecPhlWhile.ml b/src/phl/ecPhlWhile.ml index e1714332c4..cc6bbe80c0 100644 --- a/src/phl/ecPhlWhile.ml +++ b/src/phl/ecPhlWhile.ml @@ -370,7 +370,7 @@ module LossLess = struct | Fint z -> e_int z | Flocal x -> e_local x fp.f_ty - | Fop (p, tys) -> e_op p tys fp.f_ty + | Fop (p, tys) -> e_op_tc p tys fp.f_ty | Fapp (f, fs) -> e_app (aux f) (List.map aux fs) fp.f_ty | Ftuple fs -> e_tuple (List.map aux fs) | Fproj (f, i) -> e_proj (aux f) i fp.f_ty diff --git a/src/phl/ecPhlWp.mli b/src/phl/ecPhlWp.mli index c0017f5780..c386d92285 100644 --- a/src/phl/ecPhlWp.mli +++ b/src/phl/ecPhlWp.mli @@ -2,7 +2,6 @@ open EcUtils open EcParsetree open EcMatching.Position - open EcCoreGoal.FApi (* -------------------------------------------------------------------- *) diff --git a/subtypes/subtype.ec b/subtypes/subtype.ec new file mode 100644 index 0000000000..1f4c2f2535 --- /dev/null +++ b/subtypes/subtype.ec @@ -0,0 +1,107 @@ +(* ==================================================================== *) +subtype 'a word (n : int) = { + w : 'a list | size w = n +} + witness. + +op cat ['a] [n m : int] (x : {'a word n}) (y : {'a word m}) : {'a word (n+m)} = + x ++ y. + +==> (traduction) + +op cat ['a] (x : 'a word) (y : 'a word) : 'a word = + x ++ y. + +lemma cat_spec ['a] : + forall (n m : int) (x y : 'a word), + size x = n => size y = m => size (cat x y) = (n + m). + +op xor [n m : int] (w1 : {word n}) (w2 : {word m}) : {word (min (n, m))} = + ... + +lemma foo ['a] [n : int] (w1 w2 : {'a word n}) : + xor w1 w2 = xor w2 w1. + +op vectorize ['a] [n m : int] (w : {'a word (n * m)}) : {{'a word n} word m}. + +lemma vectorize_spec ['a] (w : 'a list) : size w = (n * m) => + size (vectorize w) = m + /\ (all (fun w' => size w' = n) (vectorize w)). + +-> Keeping information in application? Yes + -> should provide a syntax for giving the arguments + + {w : word 256} + + vectorize<:int, n = 4> w ==> infer: m = 64 + +-> What to do when the inference fails + 1. we reject (most likely) + 2. we open a goal + +-> In a proof script (apply: foo) or (rewrite foo) + 1. inference des dépendances (n, m, ...) + 2. décharger les conditions de bord (size w1 = n, size w2 = n) + +-> Goal + n : int + m : int + w1 : {word n} + w2 : {word m} + ==================================================================== + E[xor (cat w1 w2) (cat w2 w1)] + + rewrite foo + + n : int + m : int + w1 : {word n} + w2 : {word m} + ==================================================================== + E[xor (cat w2 w1) (cat w1 w2)] + + under condition: + exists p . size (cat w1 w2) = p /\ size (cat w2 w1) = p. + + ?p = size (cat w1 w2) + ?p = size (cat w2 w1) + +-> can be solved using a extended prolog-like engine + 1. declarations of variables (w1 : {word n}) (w2 : {word m}) + 2. prolog-like facts from operators types (-> ELPI) + 3. theories (ring / int) + +-> subtypes in procedures + + We can only depend on operators / constants. I.e. the following + program should be rejected: + + module M = { + var n : int + + proc f(x : {bool word M.n}) = { + } + } + + Question: + - What about dependent types in the type for results: + we reject programs if we cannot statically check the condition + - What about the logics? we have to patch them. + +(* ==================================================================== *) +all : 'a t * 'a -> bool + +axiom all_spec ['a] : forall (f : 'a t -> 'a) (s : 'a t), all (s, f s). + +nth ['a] 'a -> 'a list -> int -> 'a + +lemma nth_spec ['a] (x : 'a) (s : 'a list) (i : int) : + forall P, + (forall y, all<: 'a> (y, x) -> P y) -> + P x -> (forall y, all<: 'a list> (s, y) -> P y) -> P (nth x s i). + +ws : {word n} list + +nth<:word> witness ws 2 : word +nth<:{word n}> + +coercion : 'a word n -> 'a list diff --git a/tests/outline.ec b/tests/outline.ec index 3e06798c63..30587c0687 100644 --- a/tests/outline.ec +++ b/tests/outline.ec @@ -1,4 +1,4 @@ -require import AllCore. +require import AllCore Distr. type t = int. op dint : t distr. diff --git a/tests/tc-ko/ambiguous-instance.ec b/tests/tc-ko/ambiguous-instance.ec new file mode 100644 index 0000000000..6b170a94df --- /dev/null +++ b/tests/tc-ko/ambiguous-instance.ec @@ -0,0 +1,35 @@ +require import AllCore. + +(* Negative: two distinct instances of the same parametric typeclass + match the goal's args. The By-args strategy must report + "ambiguous typeclass instance" rather than degrading to a generic + "free variables" error at close time. *) +type class ['a, 'b] embed = { + op proj : embed -> 'a + op inj : 'b -> embed + axiom dummy : true +}. + +(* First instance: int * bool, with the natural projections. *) +op proj_pair_l (p : int * bool) : int = fst p. +op inj_pair_l (b : bool) : int * bool = (0, b). + +instance (int, bool) embed as pair_inst_l with (int * bool) + op proj = proj_pair_l + op inj = inj_pair_l. + +realize dummy by trivial. + +(* Second instance: bool * int, with swapped projections. Both match + (int, bool) embed. *) +op proj_pair_r (p : bool * int) : int = snd p. +op inj_pair_r (b : bool) : bool * int = (b, 0). + +instance (int, bool) embed as pair_inst_r with (bool * int) + op proj = proj_pair_r + op inj = inj_pair_r. + +realize dummy by trivial. + +(* Bare op: ambiguous, since both instances of (int, bool) embed match. *) +op test_ambiguous : int = proj (inj true). diff --git a/tests/tc-ko/bad-tvi.ec b/tests/tc-ko/bad-tvi.ec new file mode 100644 index 0000000000..d1a3159039 --- /dev/null +++ b/tests/tc-ko/bad-tvi.ec @@ -0,0 +1,23 @@ +require import AllCore. + +(* Negative: a TC-polymorphic lemma is instantiated at a type with no + matching instance. pf_check_tvi must reject this with the typed + error "type int does not satisfy typeclass constraint addmonoid". *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +lemma idm_idem ['a <: addmonoid] (x : 'a) : idm + x = x. +proof. by apply add0m. qed. + +(* No instance for [int]. *) +lemma test : true. +proof. +have := idm_idem<:int> 0. +trivial. +qed. diff --git a/tests/tc-ko/diamond-coherence.ec b/tests/tc-ko/diamond-coherence.ec new file mode 100644 index 0000000000..9e653cbff8 --- /dev/null +++ b/tests/tc-ko/diamond-coherence.ec @@ -0,0 +1,34 @@ +require import AllCore. + +(* Negative: registering two instances on the same carrier where a + shared ancestor's ops would have to disagree must hard-error with + a "diamond coherence violation" message. *) + +type class parent = { + op my_f : parent + axiom ax : forall (x : parent), x = x +}. + +type class child <: parent = { + op my_g : child + axiom bx : forall (x : child), x = x +}. + +op f_zero : int = 0. +op f_one : int = 1. +op g_zero : int = 0. + +instance parent as parent_int with int + op my_f = f_zero. + +realize ax by trivial. + +(* This second instance binds parent's op my_f to f_one, which + conflicts with the existing parent_int instance binding it to + f_zero. Phase B coherence check must hard-error. *) +instance child as child_int with int + op my_f = f_one + op my_g = g_zero. + +realize ax by trivial. +realize bx by trivial. diff --git a/tests/tc-ko/underconstrained-axiom.ec b/tests/tc-ko/underconstrained-axiom.ec new file mode 100644 index 0000000000..5c5d90b714 --- /dev/null +++ b/tests/tc-ko/underconstrained-axiom.ec @@ -0,0 +1,19 @@ +require import AllCore. + +(* Negative: a typeclass body axiom uses a grandparent's TC operator + without pinning the carrier. The typer must reject with the typed + "axiom is type-ambiguous" message rather than the raw + UninstanciateUni anomaly. *) +type class base = { + op zero : base + axiom zero_eq : zero = zero +}. + +type class tc1 <: base = { + op f1 : tc1 -> tc1 + axiom f1_id : forall (x : tc1), f1 x = x +}. + +type class tc3 <: tc1 = { + axiom tc3_extra : zero = zero +}. diff --git a/tests/tc/basic.ec b/tests/tc/basic.ec new file mode 100644 index 0000000000..09c608e734 --- /dev/null +++ b/tests/tc/basic.ec @@ -0,0 +1,29 @@ +require import AllCore. + +(* TC declaration with axioms, polymorphic operators and lemmas *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* Polymorphic op over a TC *) +op double ['a <: addmonoid] (x : 'a) = x + x. + +(* Polymorphic lemma using TC axioms *) +lemma addm0 ['a <: addmonoid] (x : 'a) : x + idm = x. +proof. by rewrite addmC add0m. qed. + +(* Section abstracting a TC-constrained type *) +section. + declare type t <: addmonoid. + + lemma double_id (x : t) : double x = x + x. + proof. by rewrite /double. qed. + + lemma id_double : double idm<:t> = idm. + proof. by rewrite /double add0m. qed. +end section. diff --git a/tests/tc/clone-with-instance.ec b/tests/tc/clone-with-instance.ec new file mode 100644 index 0000000000..a21c0e7bc0 --- /dev/null +++ b/tests/tc/clone-with-instance.ec @@ -0,0 +1,44 @@ +require import AllCore. + +(* Abstract theory parametrized by a TC carrier; cloning the theory + with a concrete carrier must thread the TC instance correctly. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +abstract theory T. + type t <: addmonoid. + + op double (x : t) : t = x + x. + + lemma double_idm : double idm = idm. + proof. by rewrite /double add0m. qed. +end T. + +(* Concrete instance for [int]. *) +op zero_int : int = 0. +op plus_int : int -> int -> int = Int.( + ). + +instance addmonoid as int_inst with int + op idm = zero_int + op (+) = plus_int. + +realize addmA by rewrite /plus_int; smt(). +realize addmC by rewrite /plus_int; smt(). +realize add0m by rewrite /plus_int /zero_int; smt(). + +(* Clone T with t = int. The carrier's TC constraint is satisfied by + int_inst. The cloned theory's lemmas/ops are usable. *) +clone T as TI with type t = int. + +(* Cloned operator [TI.double] is well-typed at the concrete carrier. *) +op test_op : int = TI.double zero_int. + +(* Cloned op reduces under [delta_tc] using the resolved concrete instance. *) +lemma test_double : TI.double zero_int = plus_int zero_int zero_int. +proof. by rewrite /TI.double. qed. diff --git a/tests/tc/clone.ec b/tests/tc/clone.ec new file mode 100644 index 0000000000..1e4c1b260c --- /dev/null +++ b/tests/tc/clone.ec @@ -0,0 +1,24 @@ +require import AllCore. + +(* Cloning a theory containing a typeclass and a TC-polymorphic lemma *) +abstract theory Algebra. + type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) + }. + + lemma addm0 ['a <: addmonoid] (x : 'a) : x + idm = x. + proof. by rewrite addmC add0m. qed. +end Algebra. + +(* The cloned typeclass and lemma are usable in the cloned theory *) +clone Algebra as A2. + +op test ['a <: A2.addmonoid] (x : 'a) = A2.(+) x A2.idm. + +lemma test_eq ['a <: A2.addmonoid] (x : 'a) : test x = x. +proof. rewrite /test. exact A2.addm0. qed. diff --git a/tests/tc/declare-type.ec b/tests/tc/declare-type.ec new file mode 100644 index 0000000000..299e8f1455 --- /dev/null +++ b/tests/tc/declare-type.ec @@ -0,0 +1,27 @@ +require import AllCore. + +(* A section using [declare type t <: tc] for an abstract carrier; the + developed operators survive section close as TC-polymorphic. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +section. + declare type t <: addmonoid. + + op double (x : t) : t = x + x. + + lemma double_idm : double idm = idm. + proof. by rewrite /double add0m. qed. +end section. + +(* After section close: [double] becomes TC-polymorphic. *) +op test_call ['a <: addmonoid] (x : 'a) : 'a = double x. + +lemma test_idm ['a <: addmonoid] : double<:'a> idm = idm. +proof. by apply double_idm. qed. diff --git a/tests/tc/diamond.ec b/tests/tc/diamond.ec new file mode 100644 index 0000000000..1a72ece68a --- /dev/null +++ b/tests/tc/diamond.ec @@ -0,0 +1,43 @@ +require import AllCore. + +(* Diamond inheritance: + base + / \ + tc1 tc2 + \ / + tc3 + Verify that ancestors are correctly walked through both branches and + that the SMT auto-axiom inclusion does not double-pull base axioms. *) + +type class base = { + op zero : base + axiom zero_idem : forall (x : base), x = x +}. + +type class tc1 <: base = { + op f1 : tc1 -> tc1 + axiom f1_id : forall (x : tc1), f1 x = x +}. + +type class tc2 <: base = { + op f2 : tc2 -> tc2 + axiom f2_id : forall (x : tc2), f2 x = x +}. + +(* tc3 inherits from tc1 — diamond closes here only on the tc1 side. *) +type class tc3 <: tc1 = { + op f3 : tc3 -> tc3 + axiom f3_id : forall (x : tc3), f3 x = x +}. + +(* Polymorphic lemma: tc3 carrier must satisfy the parent f1_id (lift=1). *) +lemma f1_via_tc3 ['a <: tc3] (x : 'a) : f1 x = x. +proof. by apply f1_id. qed. + +(* SMT auto-includes ancestor axioms — base, tc1, tc3 should all be + reachable from tc3 without duplication. *) +lemma f3_smt ['a <: tc3] (x : 'a) : f3 x = x. +proof. smt(). qed. + +lemma f1_smt ['a <: tc3] (x : 'a) : f1 x = x. +proof. smt(). qed. diff --git a/tests/tc/explicit-tvi.ec b/tests/tc/explicit-tvi.ec new file mode 100644 index 0000000000..cabe7f2a08 --- /dev/null +++ b/tests/tc/explicit-tvi.ec @@ -0,0 +1,34 @@ +require import AllCore. + +(* Explicit type-instantiation [<: int>] of a polymorphic-over-TC lemma + must pick up the matching named instance and succeed. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +op zero_int : int = 0. +op plus_int : int -> int -> int = Int.( + ). + +instance addmonoid as int_inst with int + op idm = zero_int + op (+) = plus_int. + +realize addmA by rewrite /plus_int; smt(). +realize addmC by rewrite /plus_int; smt(). +realize add0m by rewrite /plus_int /zero_int; smt(). + +lemma idm_idem ['a <: addmonoid] (x : 'a) : idm + x = x. +proof. by apply add0m. qed. + +(* Explicit TVI: should pick int_inst. *) +lemma test1 (n : int) : zero_int + n = n. +proof. by apply (idm_idem<:int> n). qed. + +(* No TVI: should also work via unification-driven instance resolution. *) +lemma test2 (n : int) : zero_int + n = n. +proof. by apply (idm_idem n). qed. diff --git a/tests/tc/grandparent-op.ec b/tests/tc/grandparent-op.ec new file mode 100644 index 0000000000..e9e54b3cc1 --- /dev/null +++ b/tests/tc/grandparent-op.ec @@ -0,0 +1,27 @@ +require import AllCore. + +(* Using a grandparent's TC operator inside a typeclass body. The + carrier is implicit, so we must pin it via [<:carrier>] when the + operator's argument types do not otherwise force the carrier. *) +type class base = { + op zero : base + axiom zero_eq : zero = zero +}. + +type class tc1 <: base = { + op f1 : tc1 -> tc1 + axiom f1_id : forall (x : tc1), f1 x = x +}. + +(* Without explicit tvi, the typer cannot infer the carrier and emits a + clear "type-ambiguous" error. The standard fix is to pin the + carrier with [<:carrier>]. *) +type class tc3 <: tc1 = { + axiom tc3_extra : (zero<:tc3>) = zero +}. + +(* When the operator's argument forces the carrier, no explicit tvi is + needed: [zero = x] implies [zero : tc3_alt] from [x : tc3_alt]. *) +type class tc3_alt <: tc1 = { + axiom tc3_via_arg : forall (x : tc3_alt), zero = x => x = zero +}. diff --git a/tests/tc/inheritance.ec b/tests/tc/inheritance.ec new file mode 100644 index 0000000000..07805d4733 --- /dev/null +++ b/tests/tc/inheritance.ec @@ -0,0 +1,29 @@ +require import AllCore. + +(* Multi-level subclass chain: addmonoid <- group, with a polymorphic + lemma at the parent level used through the subclass. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +type class group <: addmonoid = { + op opp : group -> group + axiom addmN : left_inverse idm opp (+)<:group> +}. + +(* Polymorphic lemma over [addmonoid] *) +lemma addm0 ['a <: addmonoid] (x : 'a) : x + idm = x. +proof. by rewrite addmC add0m. qed. + +(* The same lemma should be usable under the [group] subclass — the + ancestor walk surfaces the [addmonoid] constraint. *) +lemma addm0_via_group ['a <: group] (x : 'a) : x + idm = x. +proof. by apply addm0. qed. + +(* And direct use of the parent operator on a subclass-bound value. *) +op test ['a <: group] (x : 'a) : 'a = x + idm + opp x. diff --git a/tests/tc/instance.ec b/tests/tc/instance.ec new file mode 100644 index 0000000000..473e44879f --- /dev/null +++ b/tests/tc/instance.ec @@ -0,0 +1,29 @@ +require import AllCore Bool. + +(* TC + named instance for a concrete type *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +instance addmonoid as bool_xor with bool + op idm = false + op (+) = (^^). + +realize addmA by smt(). +realize addmC by smt(). +realize add0m by smt(). + +(* Use the polymorphic ops at the concrete instance type. The instance + resolution must succeed (otherwise the typing would fail). *) +op test (x : bool) = x + idm<:bool>. + +(* Unnamed instance also works (auto-named) *) +type class group <: addmonoid = { + op opp : group -> group + axiom addmN : left_inverse idm opp (+)<:group> +}. diff --git a/tests/tc/multi-instance.ec b/tests/tc/multi-instance.ec new file mode 100644 index 0000000000..6e9c3c154e --- /dev/null +++ b/tests/tc/multi-instance.ec @@ -0,0 +1,29 @@ +require import AllCore. + +(* Test that multiple named instances for the same TC at different + types coexist without interference. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* Instance for [int] *) +op zero_int : int = 0. +op plus_int : int -> int -> int = Int.( + ). + +instance addmonoid as int_inst with int + op idm = zero_int + op (+) = plus_int. + +realize addmA by rewrite /plus_int; smt(). +realize addmC by rewrite /plus_int; smt(). +realize add0m by rewrite /plus_int /zero_int; smt(). + +(* Both instance types coexist; explicit instantiation picks the right one *) +op test_int : int = idm<:int>. + +lemma test_int_eq : test_int = zero_int by rewrite /test_int; smt(). diff --git a/tests/tc/multi-param-bare-ops.ec b/tests/tc/multi-param-bare-ops.ec new file mode 100644 index 0000000000..ede6843540 --- /dev/null +++ b/tests/tc/multi-param-bare-ops.ec @@ -0,0 +1,36 @@ +require import AllCore. + +(* Mode #3: bare ops on a parametric-carrier multi-parameter typeclass. + The unifier's By-args strategy infers the carrier from the (ground) + type-class arguments when there is a unique matching instance. *) +type class ['a, 'b] embed = { + op proj : embed -> 'a + op inj : 'b -> embed + axiom dummy : true +}. + +(* Concrete instance: pair (int, bool). *) +op proj_pair (p : int * bool) : int = fst p. +op inj_pair (b : bool) : int * bool = (0, b). + +instance (int, bool) embed as pair_inst with (int * bool) + op proj = proj_pair + op inj = inj_pair. + +realize dummy by trivial. + +(* Bare ops: the carrier (int * bool) is inferred from the (int, bool) + embed instance — no explicit tvi needed. *) +op test_bare : int = proj (inj true). + +(* Same shape inside a lemma. *) +lemma round_trip (b : bool) : proj (inj b) = (0, b).`1. +proof. by rewrite /inj_pair /proj_pair. qed. + +(* Even when the user only constrains the result type, the args of the + typeclass propagate from the unique matching instance. *) +op test_proj_only (s : int * bool) : int = proj s. + +(* And when only the source type is fixed: the carrier and target are + inferred from the unique embed instance. *) +op test_inj_only (b : bool) : int * bool = inj b. diff --git a/tests/tc/multi-param.ec b/tests/tc/multi-param.ec new file mode 100644 index 0000000000..29cb5f50e7 --- /dev/null +++ b/tests/tc/multi-param.ec @@ -0,0 +1,44 @@ +require import AllCore. + +(* Multi-parameter typeclass: [embed] takes two type parameters + ['a, 'b], indexing the source/target of the embedding. *) +type class ['a, 'b] embed = { + op proj : embed -> 'a + op inj : 'b -> embed + + axiom proj_inj : + forall (x : 'a) (y : 'b), proj (inj y) = x => proj (inj y) = x +}. + +(* Polymorphic-over-multi-param lemma. The polymorphic body still needs + an explicit tvi: the carrier is a type parameter ['c], so there is + no concrete instance to drive By-args inference. *) +lemma round_trip + ['a, 'b, 'c <: ('a, 'b) embed] + (x : 'a) (y : 'b) : + proj<:'a, 'b, 'c> (inj<:'a, 'b, 'c> y) = x => + proj<:'a, 'b, 'c> (inj<:'a, 'b, 'c> y) = x. +proof. by apply proj_inj. qed. + +(* Concrete instance: pair (int, bool) carrying both. *) +op proj_pair (p : int * bool) : int = fst p. +op inj_pair (b : bool) : int * bool = (0, b). + +instance (int, bool) embed as pair_inst with (int * bool) + op proj = proj_pair + op inj = inj_pair. + +realize proj_inj by trivial. + +(* The instance specializes both type parameters. Both forms work: + the helper-op form and the bare TC op form. *) +op test_proj : int = proj_pair (inj_pair true). +op test_via_tc : int = proj (inj true). + +(* Polymorphic lemma applied at the concrete instance. The body uses + explicit tvi because the apply target is the polymorphic + [round_trip], not a TC op directly. *) +lemma round_trip_int (x : int) (y : bool) : + proj<:int, bool, (int * bool)> (inj<:int, bool, (int * bool)> y) = x => + proj<:int, bool, (int * bool)> (inj<:int, bool, (int * bool)> y) = x. +proof. by apply (round_trip<:int, bool, (int * bool)>). qed. diff --git a/tests/tc/parametric.ec b/tests/tc/parametric.ec new file mode 100644 index 0000000000..8d7c5d6a6d --- /dev/null +++ b/tests/tc/parametric.ec @@ -0,0 +1,23 @@ +require import AllCore. + +(* Parametric typeclass: a class indexed by another typeclass. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* An action of an [addmonoid] on a carrier *) +type class ['a <: addmonoid] action = { + op act : 'a -> action -> action + + axiom act_id : forall (x : action), act idm<:'a> x = x +}. + +(* Polymorphic lemma using the parametric class *) +lemma act_idmE ['a <: addmonoid, 'b <: 'a action] (x : 'b) : + act idm<:'a> x = x. +proof. by apply act_id. qed. diff --git a/tests/tc/print.ec b/tests/tc/print.ec new file mode 100644 index 0000000000..8987ccd63b --- /dev/null +++ b/tests/tc/print.ec @@ -0,0 +1,18 @@ +require import AllCore. + +(* Regression: `print` must not crash on TC-related entities, and + abstract type printers must surface their TC constraints. *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +type t <: addmonoid. + +print t. +print addmonoid. +print idm. diff --git a/tests/tc/section.ec b/tests/tc/section.ec new file mode 100644 index 0000000000..8475d4a923 --- /dev/null +++ b/tests/tc/section.ec @@ -0,0 +1,17 @@ +require import AllCore. + +(* A typeclass declared inside a section that survives section close *) +section. + type class my_monoid = { + op my_id : my_monoid + op my_op : my_monoid -> my_monoid -> my_monoid + + axiom my_left_id : forall (x : my_monoid), my_op my_id x = x + }. +end section. + +(* Reference the typeclass after the section *) +op double ['a <: my_monoid] (x : 'a) = my_op x x. + +lemma id_double ['a <: my_monoid] : double my_id<:'a> = my_id. +proof. rewrite /double my_left_id //. qed. diff --git a/tests/tc/smt.ec b/tests/tc/smt.ec new file mode 100644 index 0000000000..71b5e1dc75 --- /dev/null +++ b/tests/tc/smt.ec @@ -0,0 +1,82 @@ +require import AllCore. + +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* 1) Concrete instance: SMT pre-reduction collapses TC ops, then smt() closes. *) +op zero_int : int = 0. +op plus_int : int -> int -> int = Int.( + ). + +instance addmonoid as int_inst with int + op idm = zero_int + op (+) = plus_int. + +realize addmA by rewrite /plus_int; smt(). +realize addmC by rewrite /plus_int; smt(). +realize add0m by rewrite /plus_int /zero_int; smt(). + +lemma idm_int : (idm<:int>) = zero_int by smt(). + +(* 2) Abstract carrier with TC axiom hints: SMT chains TC axioms through + the polymorphic operator surface. *) +lemma combine_abs ['a <: addmonoid] (x y : 'a) : (idm + x) + y = x + y. +proof. smt(add0m). qed. + +lemma triple_assoc ['a <: addmonoid] (x y z w : 'a) : + ((x + y) + z) + w = x + (y + (z + w)). +proof. smt(addmA). qed. + +(* 2bis) Abstract carrier WITHOUT explicit TC axiom hints: the TC axioms + tied to the tparam constraint are auto-included by [trans_tc_axioms]. *) +lemma idm_left_nohint ['a <: addmonoid] (x : 'a) : idm + x = x. +proof. smt(). qed. + +lemma idm_right_nohint ['a <: addmonoid] (x : 'a) : x + idm = x. +proof. smt(). qed. + +(* 3) TC inheritance: parent axioms remain available to SMT. *) +type class addgroup <: addmonoid = { + op opp : addgroup -> addgroup + axiom addNm : forall (x : addgroup), opp x + x = idm +}. + +lemma group_zero ['a <: addgroup] (x : 'a) : (opp x + x) + idm = idm. +proof. smt(addNm add0m). qed. + +(* 3bis) Inheritance + no-hints: parent (addmonoid) axioms must also be + pulled in via the ancestor walk. *) +lemma group_left_nohint ['a <: addgroup] (x : 'a) : idm + x = x. +proof. smt(). qed. + +lemma group_inv_nohint ['a <: addgroup] (x : 'a) : opp x + x = idm. +proof. smt(). qed. + +(* 4) Section [declare type t <: tc] reaches SMT correctly. *) +section. + declare type t <: addmonoid. + + lemma chain (a b c : t) : ((a + idm) + b) + (idm + c) = (a + b) + c. + proof. smt(add0m addmA addmC). qed. +end section. + +(* 5) Two distinct concrete instances coexist in one goal. *) +op zero_bool : bool = false. +op or_bool : bool -> bool -> bool = (\/). + +instance addmonoid as bool_inst with bool + op idm = zero_bool + op (+) = or_bool. + +realize addmA by rewrite /or_bool; smt(). +realize addmC by rewrite /or_bool; smt(). +realize add0m by rewrite /or_bool /zero_bool; smt(). + +lemma cross (i : int) (b : bool) : + zero_int + i = i /\ (zero_bool \/ b = false \/ b). +proof. smt(). qed.