#!/usr/bin/env python3
import numpy as np
import matplotlib.pyplot as plt
def cost(x, y):
    """Return the value of the quadratic bowl ``2*x**2 + 2*y**2 + 1``.

    The expression uses only ``**``, ``*`` and ``+``, so any operands that
    support those operators (plain numbers, NumPy arrays, ...) are accepted
    and the result has whatever type those operations produce.

    Args:
        x: First coordinate.
        y: Second coordinate.

    Returns:
        ``2*x**2 + 2*y**2 + 1``; equals 1 at ``x == y == 0``.
    """
    return 2 * x**2 + 2 * y**2 + 1
# Create a new matplotlib figure, then request an axes object using the
# '3d' projection; both are kept in module-level names (fig, ax).
# NOTE(review): presumably used to plot cost() — the plotting code is not
# visible in this part of the file; confirm.
fig = plt.figure()
ax = plt.axes(projection='3d')