// Copyright (c) Microsoft Corporation 2005-2006.
// This sample code is provided "as is" without warranty of any kind. 
// We disclaim all warranties, either express or implied, including the 
// warranties of merchantability and fitness for a particular purpose. 

#light

module MyLibrary.Expr

//---------------------------------------------------------------------------
// This sample defines a symbolic differentiator.
//
// Multi-variate expressions, extensible with new unary operators, as long
// as each operator provides its semantics, i.e. an evaluation function, a
// derivative, a print function and a name.  The derivative of each unary
// operator is expressed as a function value, Differentiate (effectively a
// virtual method - each unary operator can provide a different
// implementation).  It could also be expressed as a formula in terms of some
// fixed variable, with substitution then used to apply the specification of
// the derivative.  The Print function receives a buffer and a helper
// function that prints the operand of the unary operator.
//
// Sample Task: add some further unary operators.
//
// Sample Task: modify the approach to use substitution.
//
// Sample Task: modify the program to include operators of many variables,
// such as functions of the form
//    +  ... K30 x^3     + K20 x^2     + K10 x     + K00
//    +  ... K31 x^3 y   + K21 x^2 y   + K11 x y   + K01 y
//    +  ... K32 x^3 y^2 + K22 x^2 y^2 + K12 x y^2 + K02 y^2
// whose coefficients are specified by an N-dimensional matrix.
//
// Sample exercise: Adjust to use Taylor's series approximations, where 
// evaluation is a computation over lazy data structures that evaluates 
// until an appropriate error-bound is reached, and differentiation is
// a morph from one lazy data structure to another.  Expressions would
// be Taylor's series of single variables, represented as a stream of 
// coefficients.  The Stream module from the ML compatibility 
// library will come in handy.
//---------------------------------------------------------------------------

// Symbolic expression tree.  Variables are resolved by name through an
// environment function at evaluation time (see evaluate); unary operators
// carry their own semantics in a UnaryOperator record.
type Expr = 
    | Sum of Expr * Expr             // e1 + e2
    | Prod of Expr * Expr            // e1 * e2
    | Var of string                  // named variable
    | Const of float                 // numeric constant
    | Unary of UnaryOperator * Expr  // unary operator applied to one operand
  
// Semantics of a unary operator f.
and UnaryOperator = 
    { Evaluate : float -> float; // evaluation function: computes f(v) for a numeric operand v
      // Given the operand u, builds the symbolic derivative f'(u);
      // derivative multiplies this by the operand's own derivative (chain rule).
      Differentiate: Expr -> Expr;   // symbolic differentiation function
      // Given a buffer and a callback that prints the operand into a buffer,
      // writes the operator applied to its operand (e.g. "sin(...)").
      Print: Buffer.t -> (Buffer.t -> unit) -> unit;
      Name: string } // identity: the operator's name, e.g. "sin"

//---------------------------------------------------------------------------
// Evaluation given an environment
//
// Debugging demo cheat sheet: set a breakpoint here and then run.
// Then examine the contents of "expr", looking to see which UnaryOperator 
// it is
//---------------------------------------------------------------------------

// evaluate env expr computes the numeric value of expr, looking up every
// variable by name through the function env.  Operands are evaluated left to
// right; unary operators apply their Evaluate function to the operand's value.
let rec evaluate env expr = 
    let recur = evaluate env
    match expr with 
    | Const n -> n
    | Var name -> env name
    | Sum (l, r) -> recur l + recur r
    | Prod (l, r) -> recur l * recur r
    | Unary (op, arg) -> op.Evaluate (recur arg)

//---------------------------------------------------------------------------
// Differentiation w.r.t. a single variable
//---------------------------------------------------------------------------

// derivative var expr builds the symbolic derivative of expr with respect to
// the variable named var.  The result is not simplified; see Optimize.
let rec derivative var expr = 
    let d = derivative var
    match expr with 
    | Const _ -> Const(0.0)
    | Var name -> Const(if name = var then 1.0 else 0.0)
    | Sum (u, v) -> Sum(d u, d v)
    | Prod (u, v) ->
        // product rule: (u v)' = u' v + u v'
        Sum(Prod(d u, v), Prod(u, d v))
    | Unary (op, inner) ->
        // chain rule: f(g)' = f'(g) * g'
        Prod(op.Differentiate inner, d inner)

//---------------------------------------------------------------------------
// Some unary functions and their derivatives
//---------------------------------------------------------------------------
     
// Negation, x -> -x.  Its derivative with respect to its operand is the
// constant -1; derivative multiplies that by the operand's derivative.
let negop = 
    { Name = "neg";
      Evaluate = (fun v -> -v);
      Differentiate = (fun _ -> Const(-1.0));
      Print = (fun buf printOperand -> Printf.bprintf buf "-(%t)" printOperand) }
let Neg(e) = Unary(negop, e)
             
// Sine and cosine.  Each operator's derivative refers to the other
// (sin' = cos, cos' = -sin), so the two records are defined together.
let rec sinop = 
    { Name = "sin";
      Evaluate = sin;
      Differentiate = (fun arg -> Unary(cosop, arg));
      Print = (fun buf printOperand -> Printf.bprintf buf "sin(%t)" printOperand) }
and cosop = 
    { Name = "cos";
      Evaluate = cos;
      Differentiate = (fun arg -> Neg(Unary(sinop, arg)));
      Print = (fun buf printOperand -> Printf.bprintf buf "cos(%t)" printOperand) }

// Smart constructors for sin e and cos e.
let Sin(e) = Unary(sinop, e)
let Cos(e) = Unary(cosop, e)

// The exponential function e^x, which is its own derivative.
let rec expop = 
    { Name = "exp";
      Evaluate = exp;
      Differentiate = (fun arg -> Unary(expop, arg));
      Print = (fun buf printOperand -> Printf.bprintf buf "exp(%t)" printOperand) }
let Exp(e) = Unary(expop, e)
 
// x^n for a non-negative integer exponent n, computed by n repeated
// multiplications (pow 0 x = 1.0 for every x).  Negative exponents raise a
// Failure exception: counting n down towards 0 would otherwise never terminate.
// The accumulator loop is tail-recursive; each step multiplies the running
// product by x, giving the same floating-point values as x * pow (n-1) x.
let pow n x =
    if n < 0 then failwith "pow: negative exponent"
    let rec loop k acc = if k = 0 then acc else loop (k - 1) (acc * x)
    loop n 1.0
// Pow n e builds the expression e^n for a non-negative integer n.
// Small exponents are simplified on construction: e^0 is Const 1.0 and e^1
// is e itself; larger exponents become a Unary node using powop n.
// A negative n raises Failure "invalid_arg".
let rec Pow n e = 
    if n < 0 then failwith "invalid_arg"
    match n with
    | 0 -> Const 1.0
    | 1 -> e
    | _ -> Unary(powop n, e)

// The unary operator u -> u^n.  Its derivative at operand u is n * u^(n-1),
// built with Pow so that u^1 collapses to u and u^0 to Const 1.0.
and powop n = 
    let exponent = float n
    { Name = "pow" ^ string_of_int n;
      Evaluate = pow n;
      Differentiate = (fun u -> Prod(Const exponent, Pow (n - 1) u));
      Print = (fun buf printOperand -> Printf.bprintf buf "pow%d(%t)" n printOperand) }

// Poly l e builds the polynomial in e whose coefficients are the floats in l,
// in increasing order of power: element i multiplies e^i.  For example
// Poly [2.0; 3.0; 4.0] e means 4e^2 + 3e + 2.  The result is the nested sum
// c0*e^0 + (c1*e^1 + (... + 0.0)); an empty list gives Const 0.0.  Zero
// coefficients are not skipped - they still produce a term in the tree.
let Poly l e = 
    // terms power coeffs: the sum of the remaining coefficients, starting at e^power.
    let rec terms power coeffs =
        match coeffs with
        | [] -> Const 0.0
        | c :: rest -> Sum(Prod(Const c, Pow power e), terms (power + 1) rest)
    terms 0 l

// Square and cube of an expression, built with Pow.
let Sq(e) = Pow 2 e
let Cube(e) = Pow 3 e

// The reciprocal u -> 1/u.  Evaluate returns 0.0 (not infinity) when the
// operand is exactly 0.0 - NOTE(review): presumably a deliberate sentinel to
// keep evaluation finite; confirm.
// Derivative at operand u: -1/u^2, expressed as Neg(inv(u^2)).
let rec invop = 
    { Name = "inv";
      Evaluate = (fun v -> if v = 0.0 then 0.0 else 1.0 / v);
      Differentiate = (fun u -> Neg(Unary(invop, Sq u)));
      Print = (fun buf printOperand -> Printf.bprintf buf "inv(%t)" printOperand) }
      
let Inv(e) = Unary(invop, e)

// Natural logarithm u -> log u.  Negative operands evaluate to 0.0 instead
// of being passed to log; all other values (including 0.0) go to log.
// Derivative at operand u: 1/u, i.e. Inv u.
let logop = 
    { Name = "log";
      Evaluate = (fun v -> if v < 0.0 then 0.0 else log v);
      Differentiate = Inv;
      Print = (fun buf printOperand -> Printf.bprintf buf "log(%t)" printOperand) }

let Log(e) = Unary(logop, e)
 
//---------------------------------------------------------------------------
// Optimize constants away.  Bottom-up rewriting.  For each expression
// we calculate a summary of what we learnt about it.
//---------------------------------------------------------------------------

// Summary of what Optimize learnt about a (sub)expression.
type summary = 
    | Known of float  // the expression reduced to this constant
    | Unknown         // the expression could not be reduced to a constant
  
  
// Pairs a folded constant expression with its Known summary.
let known c = (Const c, Known c)

// Optimize expr returns a simplified expression together with a summary of
// what was learnt about it.  It works bottom-up: children are optimized
// first, then
//  - operands that are all constants are folded (sums, products, unary ops);
//  - 0 + e and e + 0 become e;
//  - 0 * e and e * 0 become Const 0.0;
//  - 1 * e and e * 1 become e.
// Inner matches are parenthesised so their arms cannot capture outer ones.
let rec Optimize(expr) = 
    match expr with 
    | Const n -> known n
    | Var v -> Var v, Unknown
    | Sum (e1, e2) ->
        let left, leftInfo = Optimize e1
        let right, rightInfo = Optimize e2
        (match leftInfo, rightInfo with
         | Known x, Known y -> known (x + y)
         | Known 0.0, _ -> right, rightInfo
         | _, Known 0.0 -> left, leftInfo
         | _ -> Sum(left, right), Unknown)
    | Prod (e1, e2) ->
        let left, leftInfo = Optimize e1
        let right, rightInfo = Optimize e2
        (match leftInfo, rightInfo with
         | Known x, Known y -> known (x * y)
         | Known 0.0, _ | _, Known 0.0 -> known 0.0
         | Known 1.0, _ -> right, rightInfo
         | _, Known 1.0 -> left, leftInfo
         | _ -> Prod(left, right), Unknown)
    | Unary (op, arg) ->
        (match Optimize arg with
         | _, Known x -> known (op.Evaluate x)
         | simplified, _ -> Unary(op, simplified), Unknown)
       
//---------------------------------------------------------------------------
// Print expressions to buffers (StringBuilders)
//---------------------------------------------------------------------------

// Writes a textual rendering of expr into buf.  Sums are wrapped in
// parentheses, products print as "a * b" without parentheses, constants use
// the "%3.2f" format, and each unary operator renders itself through its
// Print field, given a callback that renders the operand.
let rec printExpr buf expr = 
    match expr with 
    | Const n -> Printf.bprintf buf "%3.2f" n
    | Var name -> Printf.bprintf buf "%s" name
    | Sum (lhs, rhs) -> Printf.bprintf buf "(%a + %a)" printExpr lhs printExpr rhs
    | Prod (lhs, rhs) -> Printf.bprintf buf "%a * %a" printExpr lhs printExpr rhs
    | Unary (op, operand) ->
        let printOperand target = printExpr target operand
        op.Print buf printOperand

// Renders e as a string after simplifying it with Optimize.
let exprToString e = 
    let simplified, _ = Optimize e
    let buf = Buffer.create 100
    printExpr buf simplified
    Buffer.contents buf
  
//---------------------------------------------------------------------------
// Derivatives w.r.t. "x"
//---------------------------------------------------------------------------

// The variable "x".
let X = Var "x"

// Ex expr x evaluates expr with "x" bound to x and every other variable
// bound to 0.0.
let Ex expr x =
    let env name = if name = "x" then x else 0.0
    evaluate env expr

// Dx expr is the (unsimplified) symbolic derivative of expr w.r.t. "x".
let Dx expr = derivative "x" expr

     
