"""A HuggingFace-style model configuration.""" from typing import Any, Dict, List from transformers import PretrainedConfig class CerberusDetConfig(PretrainedConfig): model_type = 'cerberus_v8x' def __init__( self, tasks_names: Dict[str, List[str]] = None, tasks_nc: List[int] = None, cfg: dict = None, cerber_schedule = None, stride: List[int] = None, config_name: str = None, agnostic_nms: bool=False, conf_thres: float = 0.3, iou_thres: float = 0.45, iou_thres_between_tasks: float = 0.8, **kwargs: Any ): super().__init__(**kwargs) # model configuration self.tasks_nc = tasks_nc self.tasks_names = tasks_names self.cfg = cfg self.cerber_schedule = cerber_schedule self.config_name = config_name self.strides = stride self.stride = max(self.strides) if self.strides is not None else None # model inference self.agnostic_nms = agnostic_nms self.conf_thres = conf_thres self.iou_thres = iou_thres self.iou_thres_between_tasks = iou_thres_between_tasks # additional if tasks_names: self.task_ids = [k for k, _ in tasks_names.items()] self.names: Dict[str, List[str]] = self.tasks_names # 1. Create category mapping (Task Name -> {Local ID -> Global ID} self.categories_inds_map, self.all_class_names = self._get_categories_map(self.names) # 2. Create Global Class ID -> Task Index map self.global_cls_to_task_id_map = {} for task_idx, task_name in enumerate(self.task_ids): if task_name in self.categories_inds_map: mapping = self.categories_inds_map[task_name] # global_class_id -> task_idx for global_id in mapping.values(): self.global_cls_to_task_id_map[int(global_id)] = task_idx else: self.global_cls_to_task_id_map = {} self.names = {} self.task_ids = [] self.all_class_names = [] self.categories_inds_map = {} def _get_categories_map(self, class_names: Dict[str, List[str]]): categories_inds_map: Dict[str, Dict[int, int]] = {} all_class_names: List[str] = [] tmp_categories_ids: List[List[int]] = [] for task_name, task_categories in class_names.items(): last_ind = tmp_categories_ids[-1][-1] + 1 if len(tmp_categories_ids) != 0 else 0 cur_categories_ids = list(range(len(task_categories))) tmp_categories_ids.append([ind + last_ind for ind in cur_categories_ids]) categories_inds_map[task_name] = { prev_id: new_id for prev_id, new_id in zip(cur_categories_ids, tmp_categories_ids[-1]) } all_class_names.extend(task_categories) return categories_inds_map, all_class_names