Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 13 additions & 29 deletions src/deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def label_to_color(label: int) -> tuple:


def convert_to_sv_format(
df: pd.DataFrame, width: int | None = None, height: int | None = None
df: pd.DataFrame, image: np.typing.NDArray | None = None
) -> sv.Detections | sv.KeyPoints:
"""Convert DeepForest prediction results into a supervision object.

Expand Down Expand Up @@ -198,20 +198,18 @@ def convert_to_sv_format(
# Create a reverse mapping from integer to string labels
class_name = {v: k for k, v in label_mapping.items()}

# Auto-detect width/height if missing
if width is None or height is None:
# Determine image dimensions
if image is not None:
height, width = image.shape[:2]
else:
if "image_path" not in df.columns:
raise ValueError("'image_path' column required for polygons.")

# Use the first image_path entry
image_path = df["image_path"].iloc[0]
try:
with Image.open(image_path) as img:
width, height = img.size # Get dimensions
except Exception as e:
raise ValueError(
f"Could not read image dimensions from {image_path}: {e}"
) from e
full_path = os.path.join(df.root_dir, image_path)

with Image.open(full_path) as img:
width, height = img.size

polygons = df.geometry.apply(lambda x: np.array(x.exterior.coords)).values
# as integers
Expand Down Expand Up @@ -297,8 +295,6 @@ def __check_color__(
def plot_annotations(
annotations: pd.DataFrame,
savedir: str | None = None,
height: int | None = None,
width: int | None = None,
color: list | sv.ColorPalette | None = None,
thickness: int = 2,
basename: str | None = None,
Expand All @@ -312,8 +308,7 @@ def plot_annotations(
Args:
annotations: DataFrame with annotations
savedir: Directory to save plot
height: Image height in pixels
width: Image width in pixels

color: Color for annotations
thickness: Line thickness
basename: Base name for saved file
Expand All @@ -340,8 +335,6 @@ def plot_annotations(
df=annotations,
image=image,
sv_color=annotation_color,
height=height,
width=width,
thickness=thickness,
radius=radius,
)
Expand Down Expand Up @@ -369,8 +362,6 @@ def plot_results(
results: pd.DataFrame,
ground_truth: pd.DataFrame | None = None,
savedir: str | None = None,
height: int | None = None,
width: int | None = None,
results_color: list | sv.ColorPalette | None = None,
ground_truth_color: list | sv.ColorPalette | None = None,
thickness: int = 2,
Expand All @@ -389,8 +380,7 @@ def plot_results(
results: Pandas DataFrame of prediction results.
ground_truth: Optional DataFrame of ground-truth annotations.
savedir: Optional path to save the figure; if None, plots interactively.
height: Image height in pixels. Required when using polygon geometry.
width: Image width in pixels. Required when using polygon geometry.

results_color: Single RGB list (e.g., [245, 135, 66]) or an sv.ColorPalette for per-label colors.
ground_truth_color: Single RGB list (e.g., [0, 165, 255]) or an sv.ColorPalette.
thickness: Line thickness in pixels.
Expand Down Expand Up @@ -420,8 +410,6 @@ def plot_results(
df=results,
image=image,
sv_color=results_color_sv,
height=height,
width=width,
thickness=thickness,
radius=radius,
)
Expand All @@ -432,8 +420,6 @@ def plot_results(
df=ground_truth,
image=annotated_scene,
sv_color=ground_truth_color_sv,
height=height,
width=width,
thickness=thickness,
radius=radius,
)
Expand All @@ -460,9 +446,7 @@ def plot_results(
return fig


def _plot_image_with_geometry(
df, image, sv_color, thickness=1, radius=3, height=None, width=None
):
def _plot_image_with_geometry(df, image, sv_color, thickness=1, radius=3):
"""Annotates an image with the given results.

Args:
Expand All @@ -476,7 +460,7 @@ def _plot_image_with_geometry(
"""
# Determine the geometry type
geom_type = determine_geometry_type(df)
detections = convert_to_sv_format(df, height=height, width=width)
detections = convert_to_sv_format(df, image=image)

if geom_type == "box":
bounding_box_annotator = sv.BoxAnnotator(color=sv_color, thickness=thickness)
Expand Down
28 changes: 21 additions & 7 deletions tests/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,30 @@ def test_plot_results_point_no_label(tmpdir):


def test_plot_results_polygon(gdf_poly, tmpdir):
# Call the function without height/width
visualize.plot_results(gdf_poly, savedir=tmpdir)

# Read in image and get height
image = cv2.imread(get_data("OSBS_029.tif"))
height = image.shape[0]
width = image.shape[1]
# Assertions
assert os.path.exists(os.path.join(tmpdir, "OSBS_029.png"))

# Call the function
visualize.plot_results(gdf_poly, savedir=tmpdir,height=height, width=width)

# Assertions
def test_plot_with_relative_paths(tmpdir):
# Test that plot_results and plot_annotations work with relative paths and root_dir
full_path = get_data("OSBS_029.png")
relative_name = os.path.basename(full_path)
root_dir = os.path.dirname(full_path)

data = {
'geometry': [geometry.Polygon([(10, 10), (20, 10), (20, 20), (10, 20), (15, 25)])],
'label': ['Tree'],
'image_path': [relative_name],
'score': [0.9]
}
gdf = gpd.GeoDataFrame(data)
gdf.root_dir = root_dir

visualize.plot_results(gdf, savedir=tmpdir, show=False)
visualize.plot_annotations(gdf, savedir=tmpdir, show=False)
assert os.path.exists(os.path.join(tmpdir, "OSBS_029.png"))


Expand Down
Loading