diff --git a/pathfinder/gradient_descent.py b/pathfinder/gradient_descent.py index 446cc79..4718a25 100644 --- a/pathfinder/gradient_descent.py +++ b/pathfinder/gradient_descent.py @@ -9,7 +9,7 @@ def find_sols(cost_array, jacobian, step_size=0.001, max_iterations=5, initial=( while iterations < max_iterations: curr_cost = numpy.array(cost_array(current)) if curr_cost.dot(curr_cost) < desired_cost_squared: - return ("Finished early", current) + return ("Finished early", iterations, current) gradient = .5 * numpy.matmul(numpy.transpose(jacobian(current)), curr_cost) next = current - step_size * gradient next_cost = numpy.array(cost_array(next)) @@ -22,4 +22,4 @@ def find_sols(cost_array, jacobian, step_size=0.001, max_iterations=5, initial=( tries -= 1 current = next iterations += 1 - return ("Ran out of iterations", current) + return ("Ran out of iterations", iterations, current)