Skip to content
Snippets Groups Projects
Commit 7d550759 authored by Christian's avatar Christian
Browse files

Added live_wire implementation

parent daa7e7e0
Branches
No related tags found
1 merge request!3Live wire to be implemented into GUI
This commit is part of merge request !3. Comments created here will be created in the context of that merge request.
live_wire.py 0 → 100644
+ 185
0
View file @ 7d550759
import time
import cv2
import numpy as np
import heapq
import matplotlib.pyplot as plt
from scipy.ndimage import convolve
from skimage.filters import gaussian
from skimage.feature import canny
#### Helper functions ####
def neighbors_8(x, y, width, height):
"""Return the 8-connected neighbors of (x, y)."""
for nx in (x-1, x, x+1):
for ny in (y-1, y, y+1):
if 0 <= nx < width and 0 <= ny < height:
if not (nx == x and ny == y):
yield nx, ny
def dijkstra(cost_img, seed):
"""
Dijkstra's algorithm on a 2D grid, using cost_img as the per-pixel cost.
Args:
cost_img (np.array): 2D array of costs (float).
seed (tuple): (x, y) starting coordinate.
Returns:
dist (np.float32): array of minimal cumulative cost from seed to each pixel.
parent (np.int32): array storing predecessor of each pixel for path reconstruction.
"""
height, width = cost_img.shape
# Initialize dist and parent
dist = np.full((height, width), np.inf, dtype=np.float32)
dist[seed[1], seed[0]] = 0.0
parent = -1 * np.ones((height, width, 2), dtype=np.int32)
visited = np.zeros((height, width), dtype=bool)
pq = [(0.0, seed[0], seed[1])] # (distance, x, y)
while pq:
curr_dist, cx, cy = heapq.heappop(pq)
if visited[cy, cx]:
continue
visited[cy, cx] = True
for nx, ny in neighbors_8(cx, cy, width, height):
if visited[ny, nx]:
continue
# We can take an average or sum—here, let's just sum the cost
move_cost = 0.5 * (cost_img[cy, cx] + cost_img[ny, nx])
ndist = curr_dist + move_cost
if ndist < dist[ny, nx]:
dist[ny, nx] = ndist
parent[ny, nx] = (cx, cy)
heapq.heappush(pq, (ndist, nx, ny))
return dist, parent
def backtrack_path(parent, start, end):
"""
Reconstruct path from 'end' back to 'start' using 'parent' array.
Args:
parent (np.array): shape (H, W, 2), storing (px, py) for each pixel.
start (tuple): (x, y) start coordinate.
end (tuple): (x, y) end coordinate.
Returns:
path (list of (x, y)): from start to end inclusive.
"""
path = []
current = end
while True:
path.append(current)
if current == start:
break
px, py = parent[current[1], current[0]]
current = (px, py)
path.reverse()
return path
def compute_cost(image, sigma=3.0, epsilon=1e-5):
smoothed_img = gaussian(image, sigma=sigma)
canny_img = canny(smoothed_img)
cost_img = 1 / (canny_img + epsilon)
return cost_img, canny_img
def load_image(path, type):
# Load image
if type == 'gray':
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if img is None:
raise FileNotFoundError(f"Could not read {path}")
elif type == 'color':
img = cv2.imread(path, cv2.IMREAD_COLOR)
if img is None:
raise FileNotFoundError(f"Could not read {path}")
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
else:
raise ValueError("type must be 'gray' or 'color'")
return img
def downscale(img, points, scale_percent):
if scale_percent == 100:
return img, (tuple(points[0]), tuple(points[1]))
else:
width = int(img.shape[1] * scale_percent / 100)
height = int(img.shape[0] * scale_percent / 100)
new_dimensions = (width, height)
# Downsample the image
downsampled_img = cv2.resize(img, new_dimensions, interpolation=cv2.INTER_AREA)
### SCALE POINTS
# Original image dimensions
original_width = img.shape[1]
original_height = img.shape[0]
# Downsampled image dimensions
downsampled_width = width
downsampled_height = height
# Scaling factors
scale_x = downsampled_width / original_width
scale_y = downsampled_height / original_height
# Original points
seed_xy = tuple(points[0])
target_xy = tuple(points[1])
# Scale the points
scaled_seed_xy = (int(seed_xy[0] * scale_x), int(seed_xy[1] * scale_y))
scaled_target_xy = (int(target_xy[0] * scale_x), int(target_xy[1] * scale_y))
return downsampled_img, (scaled_seed_xy, scaled_target_xy)
# Define the following
image_path = './tests/slice_60_volQ.png'
image_type = 'gray' # 'gray' or 'color'
downscale_factor = 100 # % of original size wanted
points_path = './tests/LiveWireEndPoints.npy'
# Load image
image = load_image(image_path, image_type)
# Load points
points = np.int0(np.round(np.load(points_path)))
# Downscale image and points
scaled_image, scaled_points = downscale(image, points, downscale_factor)
seed, target = scaled_points
# Compute cost image
cost_image, canny_img = compute_cost(scaled_image)
# Find path and time it
start_time = time.time()
dist, parent = dijkstra(cost_image, seed)
path = backtrack_path(parent, seed, target)
end_time = time.time()
print(f"Elapsed time for pathfinding: {end_time - start_time:.3f} seconds")
color_img = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)
for (x, y) in path:
color_img[y, x] = (0, 0, 255) # red (color of path)
plt.figure(figsize=(20,8))
plt.subplot(1,2,1)
plt.title("Cost Image")
plt.imshow(canny_img, cmap='gray')
plt.subplot(1,2,2)
plt.title("Path from Seed to Target")
plt.imshow(color_img[..., ::-1]) # BGR->RGB for plotting
plt.show()
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment