

# Contract and model inputs ----
# Initial stock price and strike price.
S0 <- 100
K <- 90
# Risk-free rate and volatility.
r <- 0.09
sigma <- 0.5
# Time to maturity and number of time steps (dt = T / N below).
# NOTE: `T` masks R's TRUE shorthand; the name is kept because the rest of the
# script refers to it.
T <- 1
N <- 30


# Tree and grid parameters ----
if (!(N >= 1)) {
  stop("N (number of time steps) must be at least 1.", call. = FALSE)
}

dt <- T / N                          # length of one time step
dz_refinement <- 300                 # z-grid spacing is sqrt(dt) / dz_refinement
dz <- sqrt(dt) / dz_refinement       # spacing of the z grid
u <- exp(sigma * sqrt(dt))           # up factor
d <- exp(-sigma * sqrt(dt))          # down factor
p1 <- (exp(r * dt) - d) / (u - d)    # risk-neutral probability of an up move

# The tree is arbitrage-free only when d < exp(r * dt) < u, i.e. 0 < p1 < 1.
# That condition also keeps p1_hat below in (0, 1).
if (!(p1 > 0 && p1 < 1)) {
  stop("Risk-neutral up probability p1 = ", p1, " is outside (0, 1); ",
       "increase N or sigma so that d < exp(r * dt) < u.", call. = FALSE)
}

x1 <- u  # up move
x2 <- d  # down move

# Up/down probabilities after changing numeraire to the stock:
# p1_hat = p1 * u / exp(r * dt).
p1_hat <- p1 * x1 / exp(r * dt)
p2_hat <- 1 - p1_hat

# Cut-off points of the z grid ----
# One step maps z to z / x + dt (x = x1 for an up move, x = x2 for a down move).
#
# Lower cut-offs run backward from zL[N + 1] = 0 via zL[j] = (zL[j + 1] - dt) * x1,
# so an up move from zL[j] lands exactly on zL[j + 1]. For z < 0 the up move
# gives the larger successor.
zL <- rev(Reduce(function(z, step) (z - dt) * x1, seq_len(N),
                 accumulate = TRUE, init = 0))

# Upper cut-offs run forward from zU[1] = 0 via zU[j] = zU[j - 1] / x2 + dt,
# so a down move from zU[j - 1] lands exactly on zU[j]. For z >= 0 the down move
# gives the larger successor.
zU <- Reduce(function(z, step) z / x2 + dt, seq_len(N),
             accumulate = TRUE, init = 0)

# p_hat-weighted mean of 1 / x over one step (algebraically equal to exp(-r * dt)).
mu <- p1_hat * (1 / x1) + p2_hat * (1 / x2)


# z grid at each time step ----
# Step j uses equally spaced points zL[j] + k * dz for k = 0, ..., nz[j], where
# nz[j] = floor((zU[j] - zL[j]) / dz) + 1. The last point is therefore the only
# one above zU[j] (up to rounding). Each grid holds nz[j] + 1 points.
nz <- floor((zU - zL) / dz) + 1
z_grid <- lapply(seq_len(N + 1), function(step) zL[step] + (0:nz[step]) * dz)

# Option values on the grids ----
# One value per grid point and time step. Step N + 1 (maturity) holds the payoff
# max(z, 0); earlier steps start at zero and are filled later by the backward
# recursion.
option_value <- lapply(z_grid, function(z) numeric(length(z)))
option_value[[N + 1]] <- pmax(z_grid[[N + 1]], 0)


# Interpolation cells of the successors ----
# A point z of step j has successors z / x1 + dt (up) and z / x2 + dt (down).
# Element [[j + 1]] of each list below refers to the grid of step j + 1:
#   mjx*    - 0-based index of the cell [zL + m * dz, zL + (m + 1) * dz]
#             that contains the successor
#   alphax* - linear-interpolation weight of that cell's left point
# Element [[1]] stays NULL.
mjx1 <- vector("list", N + 1)
mjx2 <- vector("list", N + 1)
alphax1 <- vector("list", N + 1)
alphax2 <- vector("list", N + 1)

mjx1[-1] <- lapply(seq_len(N), function(j) {
  floor((z_grid[[j]] / x1 + dt - zL[j + 1]) / dz)
})
mjx2[-1] <- lapply(seq_len(N), function(j) {
  floor((z_grid[[j]] / x2 + dt - zL[j + 1]) / dz)
})
alphax1[-1] <- lapply(seq_len(N), function(j) {
  (zL[j + 1] + (mjx1[[j + 1]] + 1) * dz - z_grid[[j]] / x1 - dt) / dz
})
alphax2[-1] <- lapply(seq_len(N), function(j) {
  (zL[j + 1] + (mjx2[[j + 1]] + 1) * dz - z_grid[[j]] / x2 - dt) / dz
})


# Version 1 (built-in approximation) -- disabled, kept for reference.
# It runs the same backward recursion as version 2, but evaluates the next-step
# values with stats::approx(..., rule = 2). Outside the grid, rule = 2 returns
# the value at the nearest grid end. Nodes above zU[j] use z * mu + dt.
#for (j in N:1) {
#  z11 <- zL[j+1] + mjx1[[j+1]]*dz
#  z12 <- zL[j+1] + (mjx1[[j+1]] + 1)*dz
#  z21 <- zL[j+1] + mjx2[[j+1]]*dz
#  z22 <- zL[j+1] + (mjx2[[j+1]] + 1)*dz
#  
#  zmjx1 <- z_grid[[j]]/x1 + dt
#  zmjx2 <- z_grid[[j]]/x2 + dt
#  
#  for (i in seq_along(z_grid[[j]])) {
#    if (z_grid[[j]][i] > zU[j]) {
#      option_value[[j]][i] <- z_grid[[j]][i] * mu + dt
#    
#      } else if (zmjx1[i] < z11[i] | zmjx1[i] > z12[i] | zmjx2[i] < z21[i] | zmjx2[i] > z22[i]) {
#      option_value[[j]][i] <- 0
#    
#      } else {
#      option_value[[j]][i] <- p1_hat * approx(z_grid[[j+1]], option_value[[j+1]], xout = zmjx1[i], rule = 2)$y +
#        p2_hat * approx(z_grid[[j+1]], option_value[[j+1]], xout = zmjx2[i], rule = 2)$y
#    }
#  }
#}


# Backward recursion, version 2 (manual interpolation) ----
# V_j(z) = p1_hat * V_{j+1}(z / x1 + dt) + p2_hat * V_{j+1}(z / x2 + dt),
# where V_{j+1} is evaluated by linear interpolation on the step-(j + 1) grid.
#
# Boundary handling:
# * A successor below zL[j + 1] contributes 0. From there even an up move at
#   every remaining step ends with z < 0, so the payoff max(z, 0) is 0.
#   Only that successor is zeroed; the other one still counts.
# * Once z > 0 it stays positive (z / x + dt > 0), so the value is exactly
#   linear: V_k(z) = a_k * z + b_k, with a_{N+1} = 1 and b_{N+1} = 0.
#   Since p1_hat + p2_hat = 1 and mu = p1_hat / x1 + p2_hat / x2:
#     a_k = mu * a_{k+1},   b_k = a_{k+1} * dt + b_{k+1}.
#   This is used for grid points above zU[j] (zU[j] >= 0) and for any
#   successor beyond the top of the next grid.

# Option value at the successors of a whole grid column.
#   z_next         : successor points (z / x + dt)
#   cell           : 0-based index of the step-(j + 1) cell containing them
#   alpha          : interpolation weight of that cell's left point
#   v_next         : option values on the step-(j + 1) grid
#   a_next, b_next : linear-region coefficients of step j + 1
successor_value <- function(z_next, cell, alpha, v_next, a_next, b_next) {
  out <- numeric(length(z_next))  # default 0: successor below zL[j + 1]
  below <- cell < 0
  above <- !below & cell + 2 > length(v_next)
  inside <- !below & !above
  left <- cell[inside] + 1
  out[inside] <- alpha[inside] * v_next[left] +
    (1 - alpha[inside]) * v_next[left + 1]
  out[above] <- a_next * z_next[above] + b_next
  out
}

a_next <- 1  # V_{N+1}(z) = z for z > 0
b_next <- 0
for (j in rev(seq_len(N))) {
  z <- z_grid[[j]]
  v_next <- option_value[[j + 1]]

  v_up <- successor_value(z / x1 + dt, mjx1[[j + 1]], alphax1[[j + 1]],
                          v_next, a_next, b_next)
  v_down <- successor_value(z / x2 + dt, mjx2[[j + 1]], alphax2[[j + 1]],
                            v_next, a_next, b_next)

  # Exact linear coefficients for step j (gives mu * z + dt when j = N).
  a_now <- mu * a_next
  b_now <- a_next * dt + b_next

  option_value[[j]] <- ifelse(z > zU[j],
                              a_now * z + b_now,
                              p1_hat * v_up + p2_hat * v_down)

  a_next <- a_now
  b_next <- b_now
}


# Price at time 0 ----
# Alternative using stats::approx() (rule = 2 clamps to the grid end values):
#S0/(T+dt)*approx(z_grid[[1]], option_value[[1]], xout = dt-K/S0*(T+dt), rule = 2)$y

# Initial state z0 = dt - K / S0 * (T + dt), i.e. (S0 * dt - K * (T + dt)) / S0.
x_val <- dt - K/S0 * (T + dt)
mxj <- floor((x_val - zL[1]) / dz)              # 0-based cell index on grid 1
alpha <- (zL[1] + (mxj + 1)*dz - x_val) / dz    # weight of the cell's left point
grid_len <- length(option_value[[1]])

if (mxj < 0) {
  # Below the lower cut-off zL[1]: the payoff max(z, 0) is 0 on every path.
  interp_val <- 0
} else if (mxj + 2 > grid_len) {
  # Beyond the top of grid 1, hence z0 > zU[1] = 0. z then stays positive on
  # every path, so the value is exactly linear in z0. The coefficients follow
  # the one-step recursion b <- a * dt + b, a <- mu * a, starting from a = 1,
  # b = 0 at maturity.
  a <- 1
  b <- 0
  for (k in seq_len(N)) {
    b <- a * dt + b
    a <- mu * a
  }
  interp_val <- a * x_val + b
} else {
  v1 <- option_value[[1]][mxj + 1]
  v2 <- option_value[[1]][mxj + 2]
  interp_val <- alpha * v1 + (1 - alpha) * v2
}
price <- S0 / (T + dt) * interp_val
cat("Price:", price)
# Previously recorded price: 7.17783351

