multimodal/vl2l/pyproject.toml (1 addition, 0 deletions)
@@ -22,6 +22,7 @@ dependencies = [
"scikit-learn",
"tabulate",
"hiclass",
"rapidfuzz",
]
dynamic = ["version"]

@@ -6,7 +6,7 @@
import subprocess
import time
from abc import ABC, abstractmethod
-from datetime import timedelta
+from datetime import timedelta  # noqa: TC003
from typing import TYPE_CHECKING, Self
from urllib.parse import urlparse
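Presumably the suppressed rule is flake8-type-checking's TC003 (typing-only standard library import), which would otherwise ask for timedelta to move under the TYPE_CHECKING block; the noqa signals that this module needs the import at runtime. A hypothetical sketch of the pattern the rule normally enforces:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Import used only in annotations; never executed at runtime.
    from datetime import timedelta

def wait_for(duration: "timedelta") -> None:  # hypothetical helper
    ...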

@@ -11,6 +11,7 @@
from hiclass.metrics import f1 # type: ignore[import-untyped]
from loguru import logger
from pydantic import ValidationError
+from rapidfuzz import fuzz
from sklearn.metrics import f1_score # type: ignore[import-untyped]
from tabulate import tabulate

@@ -107,22 +108,40 @@ def calculate_hierarchical_f1(data: list[tuple[str, str]]) -> float:
    return 0.0 if hp + hr == 0 else 2 * (hp * hr) / (hp + hr)


-def calculate_exact_match(generated_text: str, original_text: str) -> float:
-    """Calculates binary Exact Match (EM) score.
-
-    We clean the text (lowercase, strip whitespace) for a fairer comparison.
+def calculate_brand_f1_score(data: list[tuple[str, str]]) -> float:
+    """Calculate the F1 score of the brand field.

    Args:
-        generated_text: Output from the VLM.
-        original_text: Ground truth information from the dataset.
+        data: A list of tuples, where each tuple is
+            (predicted_brand_str, true_brand_str).

    Returns:
-        1 if the values match, 0 otherwise
+        F1 score
    """
-    gen = generated_text.strip().lower()
-    orig = original_text.strip().lower()
+    valid_threshold = 90
+    matches = []
+    for pred, src in data:
+        norm_truth = src.strip().lower()
+        norm_pred = pred.strip().lower()

-    return 1.0 if gen == orig else 0.0
+        # Exact match
+        if norm_truth == norm_pred:
+            matches.append(1)
+            continue
Comment on lines +128 to +130
Contributor: I'm just wondering, if it's an exact match, wouldn't the score from fuzz.ratio also be bigger than valid_threshold? Therefore, maybe there's no need to treat the exact match as a special case (that needs to be handled differently)?


+        # Fuzzy match (handles typos like "Adodas")
+        # fuzz.ratio calculates edit distance similarity (0-100)
+        score = fuzz.ratio(norm_truth, norm_pred)
+
+        # Threshold: if > 90/100 similarity, count as correct
+        if score > valid_threshold:
+            matches.append(1)
+        else:
+            matches.append(0)
+
+    # Calculate the score
+    # For 1-to-1 extraction, accuracy = recall = micro F1
+    return sum(matches) / len(matches)
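On the contributor's question above: fuzz.ratio returns 100.0 for identical strings, so the exact-match branch never changes a score; it only skips the similarity computation. A quick sketch, not part of the diff, with made-up brand pairs (note that the "Adodas" typo named in the code comment would in fact fall below the 90 threshold):

from rapidfuzz import fuzz

fuzz.ratio("adidas", "adidas")       # 100.0 -> identical strings always clear valid_threshold = 90
fuzz.ratio("adidas", "adodas")       # ~83.3 -> a one-letter typo in six characters is rejected
fuzz.ratio("hugo boss", "hugoboss")  # ~94.1 -> a dropped space still counts as a match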


def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
@@ -197,6 +216,7 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
    category_rand_pred_src = []
    is_secondhand_pred_src = []
    is_secondhand_rand_pred_src = []
+    brand_pred_src = []

    for elem in model_output:
        idx = elem["qsl_idx"]
@@ -233,9 +253,13 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
        is_secondhand_rand_pred_src.append((rand_is_secondhand,
            ground_truth_item["ground_truth_is_secondhand"]))

+        brand_pred_src.append((pred_item.brand,
+            ground_truth_item["ground_truth_brand"]))

    category_f1_score = calculate_hierarchical_f1(category_dataset_pred_src)
    hiclass_f1_score = calculate_hiclass_f1(category_dataset_pred_src)
    is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
+    brand_score = calculate_brand_f1_score(brand_pred_src)

    rand_cat_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
    rand_hiclass_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
@@ -244,9 +268,10 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:

    data = [
        ["category", category_f1_score, hiclass_f1_score,
-            rand_cat_f1_score, rand_hiclass_f1_score],
+            rand_cat_f1_score, rand_hiclass_f1_score, 0],
        ["is_secondhand", is_secondhand_f1_score, 0,
-            rand_is_seconhand_f1_score, 0],
+            rand_is_seconhand_f1_score, 0, 0],
+        ["brand", 0, 0, 0, 0, brand_score],
    ]

    logger.info(
@@ -256,7 +281,8 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
headers=["Fields", "F1 Score",
"HiClass F1 Score",
"F1 Score Random Selection",
"HiClass F1 Score Random Selection"],
"HiClass F1 Score Random Selection",
"Brand F1 Score"],
tablefmt="fancy_grid",
),
)
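The zeros in data pad the columns that do not apply to a given field, so every row lines up with the six headers. A minimal sketch of the resulting report, with made-up scores and shortened headers:

from tabulate import tabulate

rows = [
    ["category", 0.71, 0.68, 0.12, 0.11, 0],
    ["is_secondhand", 0.90, 0, 0.52, 0, 0],
    ["brand", 0, 0, 0, 0, 0.64],
]
print(tabulate(
    rows,
    headers=["Fields", "F1", "HiClass F1", "F1 rand", "HiClass F1 rand", "Brand F1"],
    tablefmt="fancy_grid",
))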
@@ -332,6 +332,9 @@ class Dataset(BaseModelWithAttributeDescriptionsFromDocstrings):
    token: str | None = None
    """The token to access the HuggingFace repository of the dataset."""

+    revision: str | None = None
+    """The revision of the dataset, e.g., a branch name, tag, or commit hash
+    of the HuggingFace repository."""
+
    split: list[str] = ["train", "test"]
    """Dataset splits to use for the benchmark, e.g., "train" and "test". You can add
    multiple splits by repeating the same CLI flag multiple times, e.g.:
@@ -445,8 +448,8 @@ class ProductMetadata(BaseModelWithAttributeDescriptionsFromDocstrings):
    Each categorical level is separated by " > ".
    """

-    brands: list[str]
-    """The brands of the product, e.g., ["giorgio armani", "hugo boss"]."""
+    brand: str
+    """The brand of the product, e.g., "giorgio armani"."""

    is_secondhand: bool
    """True if the product is second-hand, False otherwise."""
@@ -53,6 +53,7 @@ def __init__(
        self.dataset = load_dataset(
            dataset.repo_id,
            token=dataset.token,
+            revision=dataset.revision,
            split="+".join(dataset.split),
        )
        logger.debug(
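load_dataset accepts a branch name, tag, or commit SHA as revision, which lets the benchmark pin an immutable snapshot of the dataset. A minimal sketch; the repo id and SHA are placeholders:

from datasets import load_dataset

# Pin to an exact commit so reruns see identical data.
ds = load_dataset(
    "example-org/products",  # placeholder repo id
    revision="6f1f2c0",      # placeholder commit SHA
    split="train+test",
)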
@@ -498,7 +499,7 @@ def formulate_loaded_sample(self, sample: dict[str, Any]) -> LoadedSample:
"content": f"""Please analyze the product from the user prompt
and provide the following fields in a valid JSON object:
- category
- brands
- brand
- is_secondhand

You must choose only one, which is the most appropriate, correct, and specifc
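For illustration, a hypothetical model reply carrying the three requested fields, validated against a stand-in for the ProductMetadata schema shown earlier (assumes pydantic v2; all values are made up):

from pydantic import BaseModel

class ProductMetadataSketch(BaseModel):
    # Stand-in for the real ProductMetadata model above.
    category: str
    brand: str
    is_secondhand: bool

reply = '{"category": "fashion > shoes > sneakers", "brand": "adidas", "is_secondhand": false}'
print(ProductMetadataSketch.model_validate_json(reply))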