# Inputs ----
# Contract
T <- 1        # horizon; dt = T / N. NOTE(review): masks the built-in T (TRUE)
lamb <- 1     # multiplier in the terminal payoff max(z - lamb, 0)
S0 <- 100     # scales the final per-unit value into a price
# Market
r <- 0.05     # rate used in the one-step growth factor exp(r * dt)
sigma <- 0.5  # volatility in the CRR factors exp(+/- sigma * sqrt(dt))
# Discretization
N <- 100      # number of time steps

# Tree step, grid spacing and probabilities ----
dt <- T / N                  # length of one time step
dz <- sqrt(dt) / 500         # spacing of the z-grid
exp_r_dt <- exp(r * dt)      # one-step growth factor
u <- exp(sigma * sqrt(dt))   # CRR up factor
d <- exp(-sigma * sqrt(dt))  # CRR down factor
x1 <- u                      # factor applied on an up move
x2 <- d                      # factor applied on a down move
p1 <- (exp_r_dt - d) / (u - d)  # risk-neutral up probability
# Up probability re-weighted by x1 / exp_r_dt (stock used as numeraire).
p1_hat <- p1 * x1 / exp_r_dt
# Expected 1 / (one-step factor) under p1_hat; algebraically equal to
# 1 / exp_r_dt.
mu <- p1_hat * (1 / x1) + (1 - p1_hat) * (1 / x2)

# Lower / upper cut-off points for z on every time index ----
# Both start from a band of +/- dt around z = 1. zL applies the up factor x1
# at every step (smallest reachable z); zU applies the down factor x2 at
# every step (largest reachable z).
zL <- c(1 - dt, numeric(N))
zU <- c(1 + dt, numeric(N))

for (j in 2:(N + 1)) {
  zL[j] <- ((j - 1) / j) * zL[j - 1] / x1 + 1 / j
  zU[j] <- ((j - 1) / j) * zU[j - 1] / x2 + 1 / j
}

# Terminal slice ----
# Number of dz-intervals needed to cover the widest band [zL[j], zU[j]];
# every time slice uses this many intervals (mold + 1 grid points).
mold <- max(0, floor((zU - zL) / dz))

gridold <- seq(zL[N + 1], by = dz, length.out = mold + 1)

# Terminal payoff per unit of the terminal stock price: max(z - lamb, 0).
opt_old <- pmax(gridold - lamb, 0)

# Call variant (NOTE(review): if z is the running average over the stock
# price, this is max(lamb * S - A, 0)): use the line below instead.
# opt_old <- pmax(lamb - gridold, 0)

# Backward induction ----
# Slice j holds values on seq(zL[j], by = dz, length.out = m + 1). From each
# grid point z, an up move (x1) leads to term1 and a down move (x2) to term2
# on slice j + 1. The values there are read off slice j + 1 by linear
# interpolation and combined with weights p1_hat / (1 - p1_hat).
# NOTE(review): the "0 below the grid" rules suit the default payoff
# max(z - lamb, 0); they would need revisiting for the call variant.

# Linearly interpolate the next-slice values `opt_old` (grid starting at
# `z_low`, spacing `dz`, m + 1 points) at the points `term`.
# With k = floor((term - z_low) / dz):
#   k >  m      -> boundary value z * mu + dt
#   k == m      -> blend of the last grid value and the boundary value
#   -1 <= k < m -> blend of nodes k + 1 and k + 2 (node 0 counts as 0),
#                  weight clipped to [0, 1]
#   k <  -1     -> 0
# Returns a numeric vector the same length as `term`.
interp_next_slice <- function(term, z, opt_old, z_low, dz, m, mu, dt) {
  k <- floor((term - z_low) / dz)
  alpha <- (z_low + (k + 1) * dz - term) / dz  # weight on the lower node
  boundary <- z * mu + dt
  out <- numeric(length(term))  # entries with k < -1 stay 0

  above <- k > m
  out[above] <- boundary[above]

  at_top <- k == m  # k is integer-valued, so this is exact
  out[at_top] <- alpha[at_top] * opt_old[m + 1] +
    (1 - alpha[at_top]) * boundary[at_top]

  inside <- k >= -1 & k < m
  k_in <- k[inside]
  lower <- numeric(length(k_in))  # the node just below the grid counts as 0
  has_lower <- k_in >= 0
  lower[has_lower] <- opt_old[k_in[has_lower] + 1]
  upper <- opt_old[k_in + 2]
  a <- pmin(pmax(alpha[inside], 0), 1)
  out[inside] <- a * lower + (1 - a) * upper

  out
}

for (j in rev(seq_len(N))) {
  m <- mold
  gridnew <- seq(zL[j], by = dz, length.out = m + 1)

  j_j1 <- j / (j + 1)
  inv_j1 <- 1 / (j + 1)
  # Image of every grid point after an up / down move.
  term1 <- j_j1 * gridnew / x1 + inv_j1
  term2 <- j_j1 * gridnew / x2 + inv_j1

  c_1 <- interp_next_slice(term1, gridnew, opt_old, zL[j + 1], dz, m, mu, dt)
  c_2 <- interp_next_slice(term2, gridnew, opt_old, zL[j + 1], dz, m, mu, dt)
  opt_new <- p1_hat * c_1 + (1 - p1_hat) * c_2

  # Points outside [zL[j], zU[j]] are overridden; "above zU" takes
  # precedence, so it is applied last.
  below <- gridnew < zL[j]
  opt_new[below] <- 0
  above <- gridnew > zU[j]
  opt_new[above] <- gridnew[above] * mu + dt

  # American style (early exercise), as in the original commented variant:
  # opt_new[!below] <- pmax(opt_new[!below], gridnew[!below] - lamb)

  # update for next iteration
  opt_old <- opt_new
  gridold <- gridnew
  mold <- m
}

# Time-0 price ----
# Interpolate the time-0 slice (grid zL[1] + (0:m) * dz) at z = x_val and
# rescale by S0. x_val lies between 1-based nodes mxj + 1 and mxj + 2, and
# alpha is the weight on the lower node. This matches the convention used in
# the backward induction; the previous code read nodes mxj + 2 / mxj + 3,
# i.e. it priced at x_val + dz.
x_val <- 1
mxj <- floor((x_val - zL[1]) / dz)
stopifnot(mxj >= 0, mxj + 2 <= length(opt_new))
alpha <- (zL[1] + (mxj + 1) * dz - x_val) / dz
interp_val <- alpha * opt_new[mxj + 1] + (1 - alpha) * opt_new[mxj + 2]

print(S0 * interp_val) # price

