#!/usr/bin/env python3
"""Define a quadratic cost function and create a 3-D matplotlib axes.

The axes are presumably meant for plotting the surface of ``cost``
(TODO confirm against the rest of the script).
"""
import numpy as np
import matplotlib.pyplot as plt


def cost(x, y):
    """Return the quadratic cost ``2*x**2 + 2*y**2 + 1``.

    Uses only arithmetic operators, so it works element-wise on scalars and
    on NumPy arrays (e.g. the output of ``np.meshgrid``). That lets a whole
    grid be evaluated in one call. For real inputs the minimum value 1 is
    reached at ``x == y == 0``.

    Args:
        x: First coordinate (scalar or array).
        y: Second coordinate (scalar or array, broadcast-compatible with x).

    Returns:
        The cost value(s), with the same shape as the broadcast of x and y.
    """
    return 2 * x**2 + 2 * y**2 + 1


fig = plt.figure()
# projection='3d' is registered automatically on matplotlib >= 3.2; older
# versions need `from mpl_toolkits.mplot3d import Axes3D` before this call.
ax = plt.axes(projection='3d')