Return iteration count
This commit is contained in:
parent
026519697b
commit
0c8527704d
@ -9,7 +9,7 @@ def find_sols(cost_array, jacobian, step_size=0.001, max_iterations=5, initial=(
|
|||||||
while iterations < max_iterations:
|
while iterations < max_iterations:
|
||||||
curr_cost = numpy.array(cost_array(current))
|
curr_cost = numpy.array(cost_array(current))
|
||||||
if curr_cost.dot(curr_cost) < desired_cost_squared:
|
if curr_cost.dot(curr_cost) < desired_cost_squared:
|
||||||
return ("Finished early", current)
|
return ("Finished early", iterations, current)
|
||||||
gradient = .5 * numpy.matmul(numpy.transpose(jacobian(current)), curr_cost)
|
gradient = .5 * numpy.matmul(numpy.transpose(jacobian(current)), curr_cost)
|
||||||
next = current - step_size * gradient
|
next = current - step_size * gradient
|
||||||
next_cost = numpy.array(cost_array(next))
|
next_cost = numpy.array(cost_array(next))
|
||||||
@ -22,4 +22,4 @@ def find_sols(cost_array, jacobian, step_size=0.001, max_iterations=5, initial=(
|
|||||||
tries -= 1
|
tries -= 1
|
||||||
current = next
|
current = next
|
||||||
iterations += 1
|
iterations += 1
|
||||||
return ("Ran out of iterations", current)
|
return ("Ran out of iterations", iterations, current)
|
||||||
|
Reference in New Issue
Block a user