Library scope_extrusion

Require Import base.

Require Import nominal.
Require Import syntax.
Require Import basic_semantics.
Require Import expr_semantics.
Require Import par_semantics.
Require Import ctl_semantics.
Require Import var_semantics.
Require Import stmt_semantics.
Require Import config_stmt_semantics.

(** Scope extrusion of a local variable out of the LEFT operand of a
    sequential composition: [seq_prog (var_prog v p) q] is related by [pacc]
    to [var_prog v (seq_prog p q)].

    - [Hread_recp], [Hcrash]: side conditions on [psys p]; [read_receptive]
      and [mho_obs_crashes] are defined in the imported semantics modules.
    - [~support q v]: [v] must not be in the support of [q], presumably so that
      widening the scope of [v] over [q] cannot capture a name of [q].

    NOTE(review): [pacc] is defined elsewhere; it looks like a program
    approximation/refinement relation, but confirm its direction before
    relying on it. *)
Lemma seq_var_extrusion1 : forall v p q
  (Hread_recp : read_receptive (psys p))
  (Hcrash: mho_obs_crashes (psys p)),

  ~support q v ->
  pacc
    (seq_prog (var_prog v p) q)
    (var_prog v (seq_prog p q)).

(** Scope extrusion of a local variable out of the RIGHT operand of a
    sequential composition: [seq_prog p (var_prog v q)] is related by [pacc]
    to [var_prog v (seq_prog p q)].

    - [Hread_recp], [Hcrash]: side conditions on [psys q] (the operand that
      carries the declaration), defined in the imported semantics modules.
    - [~support p v]: the left operand [p] must not mention [v], since the
      declaration's scope is widened over it. *)
Lemma seq_var_extrusion2 : forall v p q
  (Hread_recp : read_receptive (psys q))
  (Hcrash: mho_obs_crashes (psys q)),

  ~support p v ->
  pacc
    (seq_prog p (var_prog v q))
    (var_prog v (seq_prog p q)).

(** Scope extrusion of a local variable out of the THEN branch of a
    conditional: [ifte_prog e (var_prog v p) q] is related by [pacc] to
    [var_prog v (ifte_prog e p q)].

    After extrusion the scope of [v] also covers the condition [e] and the
    else branch [q], so neither may have [v] in its support.  [Hread_recp]
    and [Hcrash] are side conditions on [psys p]. *)
Lemma ifte_var_extrusion1 : forall v e p q
  (Hread_recp : read_receptive (psys p))
  (Hcrash: mho_obs_crashes (psys p)),

  ~support e v ->
  ~support q v ->
  pacc
    (ifte_prog e (var_prog v p) q)
    (var_prog v (ifte_prog e p q)).

(** Scope extrusion of a local variable out of the ELSE branch of a
    conditional: [ifte_prog e p (var_prog v q)] is related by [pacc] to
    [var_prog v (ifte_prog e p q)].

    The condition [e] and the then branch [p] must not have [v] in their
    support.  [Hread_recp] and [Hcrash] are side conditions on [psys q]. *)
Lemma ifte_var_extrusion2 : forall v e p q
  (Hread_recp : read_receptive (psys q))
  (Hcrash: mho_obs_crashes (psys q)),

  ~support e v ->
  ~support p v ->
  pacc
    (ifte_prog e p (var_prog v q))
    (var_prog v (ifte_prog e p q)).

(** Scope extrusion of a local variable out of a loop body:
    [while_prog e (var_prog v p)] is related by [pacc] to
    [var_prog v (while_prog e p)].

    In the second program a single declaration of [v] encloses the whole
    loop, rather than one per iteration.  The loop condition [e] must not
    have [v] in its support.  [Hread_recp] and [Hcrash] are side conditions
    on [psys p]. *)
Lemma while_var_extrusion : forall v e p
  (Hread_recp : read_receptive (psys p))
  (Hcrash: mho_obs_crashes (psys p)),

  ~support e v ->
  pacc
    (while_prog e (var_prog v p))
    (var_prog v (while_prog e p)).

(** The denotation of every statement satisfies [read_receptive].  This has
    the same shape as the [Hread_recp] hypotheses of the program-level
    extrusion lemmas, presumably so they can be discharged for denoted
    statements. *)
Lemma stmt_denote_read_receptive : forall s,
  read_receptive (psys (stmt_denote s)).

(** The denotation of every statement satisfies [mho_obs_crashes].  This has
    the same shape as the [Hcrash] hypotheses of the program-level extrusion
    lemmas, presumably so they can be discharged for denoted statements. *)
Lemma stmt_denote_mho_obs_crashes : forall s,
  mho_obs_crashes (psys (stmt_denote s)).

(** Wrap [s] in one [local] declaration per name in [l].  The head of [l]
    becomes the outermost declaration, e.g.
    [apply_vars (x :: y :: nil) s = local x (local y s)].
    An empty list leaves [s] unchanged. *)
Fixpoint apply_vars (l:list VAR) (s:STMT) {struct l} : STMT :=
  match l with
  | x :: rest => local x (apply_vars rest s)   (* outermost name first *)
  | nil => s
  end.

(** A name in the support of [apply_vars l s] is in the support of [s] and is
    not one of the names in [l].  The declarations in [l] bind their names,
    so those names are not in the support of the wrapped statement. *)
Lemma apply_vars_support : forall l s v,
  support (apply_vars l s) v ->
  support s v /\ ~In v l.

(** [apply_vars] commutes with the permutation action [papp].  Renaming the
    wrapped statement is the same as renaming both the list of declared
    names and the body. *)
Lemma apply_vars_perm : forall l s p,
  papp p (apply_vars l s) =
  apply_vars (papp p l) (papp p s).

(** Declaring the names of [l1 ++ l2] is the same as declaring those of [l2]
    innermost and then those of [l1] around them.  This follows from
    [apply_vars] putting the head of the list outermost. *)
Lemma apply_vars_app : forall(l1 l2:list VAR) (s:STMT),
  apply_vars (l1++l2) s =
  apply_vars l1 (apply_vars l2 s).

(** Congruence of [apply_vars] for [pacc] on denotations.  Wrapping two
    [pacc]-related statements in the same list of local declarations keeps
    them related. *)
Lemma apply_vars_cong : forall l s1 s2,
  pacc (stmt_denote s1) (stmt_denote s2) ->
  pacc (stmt_denote (apply_vars l s1))
       (stmt_denote (apply_vars l s2)).

(** Statement-level version of [seq_var_extrusion1].  A [local] in the left
    operand of [seq] can be moved outward when [v] is not in the support of
    [q].  No receptiveness/crash hypotheses appear here, presumably because
    [stmt_denote_read_receptive] and [stmt_denote_mho_obs_crashes] supply
    them. *)
Lemma stmt_seq_var_extrusion1 : forall v p q,
  ~support q v ->
  pacc
    (stmt_denote (seq (local v p) q))
    (stmt_denote (local v (seq p q))).

(** Statement-level version of [seq_var_extrusion2].  A [local] in the right
    operand of [seq] can be moved outward when [v] is not in the support of
    [p]. *)
Lemma stmt_seq_var_extrusion2 : forall v p q,
  ~support p v ->
  pacc
    (stmt_denote (seq p (local v q)))
    (stmt_denote (local v (seq p q))).

(** [var_prog x] preserves [pacc].  If [p1] and [p2] are related, so are the
    programs obtained by declaring [x] around each of them. *)
Lemma var_prog_pacc' : forall x p1 p2,
  pacc p1 p2 ->
  pacc (var_prog x p1) (var_prog x p2).

(** List version of [stmt_seq_var_extrusion1].  All declarations in [l] are
    moved out of the left operand of [seq] at once.  The hypothesis says that
    no name of [l] is in the support of [q]. *)
Lemma stmt_seq_vars_extrusion1 : forall l p q,
  (forall x, support q x -> In x l -> False) ->
  pacc
    (stmt_denote (seq (apply_vars l p) q))
    (stmt_denote (apply_vars l (seq p q))).

(** List version of [stmt_seq_var_extrusion2].  All declarations in [l] are
    moved out of the right operand of [seq] at once.  The hypothesis says
    that no name of [l] is in the support of [p]. *)
Lemma stmt_seq_vars_extrusion2 : forall l p q,
  (forall x, support p x -> In x l -> False) ->
  pacc
    (stmt_denote (seq p (apply_vars l q)))
    (stmt_denote (apply_vars l (seq p q))).

(** Statement-level version of [ifte_var_extrusion1].  A [local] in the then
    branch can be moved outward when neither the condition [e] nor the else
    branch [q] has [v] in its support. *)
Lemma stmt_ifte_var_extrusion1 : forall v e p q,
  ~support e v ->
  ~support q v ->
  pacc
    (stmt_denote (ifte e (local v p) q))
    (stmt_denote (local v (ifte e p q))).

(** Statement-level version of [ifte_var_extrusion2].  A [local] in the else
    branch can be moved outward when neither the condition [e] nor the then
    branch [p] has [v] in its support. *)
Lemma stmt_ifte_var_extrusion2 : forall v e p q,
  ~support e v ->
  ~support p v ->
  pacc
    (stmt_denote (ifte e p (local v q)))
    (stmt_denote (local v (ifte e p q))).

(** List version of [stmt_ifte_var_extrusion1].  All declarations in [l] are
    moved out of the then branch.  No name of [l] may be in the support of
    the condition [e] or of the else branch [q]. *)
Lemma stmt_ifte_vars_extrusion1 : forall l e p q,
  (forall x, support e x -> In x l -> False) ->
  (forall x, support q x -> In x l -> False) ->
  pacc
    (stmt_denote (ifte e (apply_vars l p) q))
    (stmt_denote (apply_vars l (ifte e p q))).

(** List version of [stmt_ifte_var_extrusion2].  All declarations in [l] are
    moved out of the else branch.  No name of [l] may be in the support of
    the condition [e] or of the then branch [p]. *)
Lemma stmt_ifte_vars_extrusion2 : forall l e p q,
  (forall x, support e x -> In x l -> False) ->
  (forall x, support p x -> In x l -> False) ->
  pacc
    (stmt_denote (ifte e p (apply_vars l q)))
    (stmt_denote (apply_vars l (ifte e p q))).

(** Statement-level version of [while_var_extrusion].  A [local] in the loop
    body can be moved out so that it encloses the whole loop, provided the
    loop condition [e] does not have [v] in its support. *)
Lemma stmt_while_var_extrusion : forall v e p,
  ~support e v ->
  pacc
    (stmt_denote (while e (local v p)))
    (stmt_denote (local v (while e p))).

(** List version of [stmt_while_var_extrusion].  All declarations in [l] are
    moved out of the loop body.  No name of [l] may be in the support of the
    loop condition [e]. *)
Lemma stmt_while_vars_extrusion : forall l e p,
  (forall x, support e x -> In x l -> False) ->
  pacc
    (stmt_denote (while e (apply_vars l p)))
    (stmt_denote (apply_vars l (while e p))).

(** Every identifier that syntactically appears in [s], in left-to-right
    order.  Names bound by [local] are included along with identifiers
    collected from expressions via [expr_idents].  The result may contain
    duplicates, because it is built by plain list concatenation. *)
Fixpoint stmt_idents (s:STMT) : list VAR :=
  match s with
  | local x b => x :: stmt_idents b          (* the bound name itself counts *)
  | seq a b => stmt_idents a ++ stmt_idents b
  | ifte c a b => expr_idents c ++ stmt_idents a ++ stmt_idents b
  | while c b => expr_idents c ++ stmt_idents b
  | expr c => expr_idents c
  end.

(** Every free identifier of [s] (as computed by [stmt_free_idents]) is among
    [stmt_idents s].  So [stmt_idents] over-approximates the free names, and
    a bound larger than all of [stmt_idents s] is also fresh for the free
    names. *)
Lemma stmt_idents_incl : forall s x,
  In x (stmt_free_idents s) -> In x (stmt_idents s).

(** [stmt_idents] commutes with the permutation action [papp].  Collecting
    the identifiers of a renamed statement gives the renamed identifier
    list. *)
Lemma stmt_idents_papp : forall s p,
  stmt_idents (papp p s) = papp p (stmt_idents s).

(** Result of [do_scope_extrusion], as a triple [extr v l s'] with:
    - [v]: the next unused name in the fresh-name counter threaded through
      the traversal;
    - [l]: the fresh names introduced for the [local] declarations that were
      hoisted, listed outermost first;
    - [s']: the rewritten statement, with those declarations removed. *)
Inductive extr_state :=
  | extr : VAR -> list VAR -> STMT -> extr_state.

(** Remove every [local] declaration from [s] and record a fresh name for
    each one.

    [v0] is the first candidate fresh name.  Each [local] consumes the
    current counter value and increments it with [S].  The counter is
    threaded left to right through sub-statements, so names allocated in
    different sub-statements are distinct.  The result is [extr v' l s'],
    where [v'] is the final counter value, [l] lists the allocated names, and
    [s'] is the declaration-free body.  Re-wrapping [s'] with
    [apply_vars l] is intended to recover a program related to [s] by
    [pacc]; see [do_scope_extrusion_correctness]. *)
Fixpoint do_scope_extrusion (v0:VAR) (s:STMT) : extr_state :=
  match s with
  (* Expressions contain no declarations: nothing to hoist. *)
  | expr e => extr v0 nil (expr e)

  (* Process [s1] first, then continue the counter into [s2].  Names from
     [s1] come first in the combined list. *)
  | seq s1 s2 =>
      let (v1, l1, s1') := do_scope_extrusion v0 s1 in
      let (v2, l2, s2') := do_scope_extrusion v1 s2 in
        extr v2 (l1++l2) (seq s1' s2')

  (* Hoist the body's declarations first.  Then rename the bound [v] to the
     fresh name [v1] with [papp (perm_swap v v1)], presumably a swap of the
     two names.  [v1] goes at the head of the list, making it the outermost
     declaration, and the counter advances past it. *)
  | local v s =>
      let (v1, l1, s') := do_scope_extrusion v0 s in
        extr (S v1) (v1::l1) (papp (perm_swap v v1) s')

  (* Same threading as [seq]; the condition [e] is left untouched. *)
  | ifte e s1 s2 =>
      let (v1, l1, s1') := do_scope_extrusion v0 s1 in
      let (v2, l2, s2') := do_scope_extrusion v1 s2 in
        extr v2 (l1++l2) (ifte e s1' s2')

  (* Declarations in the loop body are hoisted out of the whole loop. *)
  | while e s =>
      let (v1, l1, s') := do_scope_extrusion v0 s in
        extr v1 l1 (while e s')
  end.

(** [no_locals s] holds iff [s] contains no [local] declaration anywhere,
    including inside conditional branches and loop bodies.  Expressions are
    not inspected: [expr _] is always [True]. *)
Fixpoint no_locals (s:STMT) : Prop :=
  match s with
  | local _ _ => False                        (* the only failing case *)
  | while _ b => no_locals b
  | ifte _ a b => no_locals a /\ no_locals b
  | seq a b => no_locals a /\ no_locals b
  | expr _ => True
  end.

(** Renaming by a permutation does not introduce [local] declarations.  If
    [s] is declaration-free, so is [papp p s]. *)
Lemma no_locals_papp : forall s p,
  no_locals s -> no_locals (papp p s).

(** Invariants of [do_scope_extrusion].  Assume every identifier of [s] is
    strictly below the starting counter [v0].  Then the result
    [extr v' l s'] satisfies:
    - the counter never decreases: [v0 <= v'];
    - every allocated name lies in the half-open range [v0, v'), so it is
      distinct from every identifier of [s];
    - every identifier of [s'] is either an allocated name or an identifier
      of [s];
    - [s'] contains no [local] declarations;
    - [stmt_denote s] and [stmt_denote (apply_vars l s')] are related by
      [pacc]. *)
Lemma do_scope_extrusion_correctness : forall s v0,
  (forall v, In v (stmt_idents s) -> v0 > v) ->

  let (v', l, s') := do_scope_extrusion v0 s in

  v0 <= v' /\
  (forall x, In x l -> v0 <= x < v') /\
  (forall x, In x (stmt_idents s') -> In x l \/ In x (stmt_idents s)) /\
  no_locals s' /\
  pacc
    (stmt_denote s)
    (stmt_denote (apply_vars l s')).

(** Top-level scope extrusion.  The fresh-name counter is seeded one past the
    largest identifier of [s], presumably the maximum computed by [list_max],
    so the precondition of [do_scope_extrusion_correctness] holds.  The
    result pairs the hoisted names with the declaration-free body; the final
    counter value is discarded. *)
Definition scope_extrusion (s:STMT) : list VAR * STMT :=
  let (_, l, s') := do_scope_extrusion (S (list_max (stmt_idents s))) s
  in (l,s').

(** Main result.  [scope_extrusion s] returns a body [s'] with no [local]
    declarations.  Re-wrapping [s'] in the returned names with [apply_vars]
    gives a statement whose denotation is related to that of [s] by [pacc].
    So every statement can be put in a form where all declarations are at
    the top. *)
Theorem scope_extrusion_correctness : forall s,
  let (l,s') := scope_extrusion s in
  no_locals s' /\
  pacc (stmt_denote s) (stmt_denote (apply_vars l s')).