diff --git a/panos/base.py b/panos/base.py index 0f3d9f77..5eb667c2 100644 --- a/panos/base.py +++ b/panos/base.py @@ -1604,13 +1604,19 @@ def _update_reference_in_objects( **kwargs ): update_needed = False - # Find any current references to self and remove them, unless it is the desired reference + + # Resolve target upfront for identity-based skip + target_obj = None + if reference_name is not None: + target_obj = parent.find_or_create(reference_name, reference_type, **kwargs) + + # Find any current references to self and remove them, unless it is the target if exclusive: for obj in allobjects: references = getattr(obj, reference_var) if not references: continue - elif reference_name is not None and obj.uid == reference_name: + elif target_obj is not None and obj is target_obj: continue elif isinstance(references, list) and self in references: update_needed = True @@ -1628,40 +1634,39 @@ def _update_reference_in_objects( if update: obj.update(reference_var) - # Add new reference to self in requested object - if reference_name is not None: - obj = parent.find_or_create(reference_name, reference_type, **kwargs) - var = getattr(obj, reference_var) + # Add new reference to self in target object + if target_obj is not None: + var = getattr(target_obj, reference_var) if var_type == "list": if var is None: update_needed = True setattr( - obj, + target_obj, reference_var, [ self, ], ) if update: - obj.update(reference_var) + target_obj.update(reference_var) elif not isinstance(var, list): if var != self and var != str(self): update_needed = True - setattr(obj, reference_var, [var, self]) + setattr(target_obj, reference_var, [var, self]) if update: - obj.update(reference_var) + target_obj.update(reference_var) elif self not in var and str(self) not in var: update_needed = True var.append(self) if update: - obj.update(reference_var) + target_obj.update(reference_var) elif var != self and var != str(self): update_needed = True - setattr(obj, reference_var, self) + setattr(target_obj, reference_var, self) if update: - obj.update(reference_var) + target_obj.update(reference_var) if return_type == "object": - return obj + return target_obj if return_type == "bool": return update_needed @@ -1678,6 +1683,8 @@ def _set_reference( running_config, return_type, name_only, + parent_type=None, + parent_name=None, **kwargs ): """Used by helper methods to set references between objects @@ -1685,6 +1692,10 @@ def _set_reference( For example, set_zone() would set the zone for an interface by creating a reference from the zone to the interface. If the desired reference already exists then nothing happens. + When parent_type and parent_name are provided, handles nested references + where reference_type objects are children of parent_type objects + (e.g., LogicalRouter -> Vrf -> interface). + This function has two modes: refresh=True and refresh=False. You should only ever use refresh=False if: @@ -1699,11 +1710,35 @@ def _set_reference( if return_type not in ("bool", "object"): raise ValueError("Unknown return_type specified: {0}".format(return_type)) - # Get all the objects and the parent + object_type = parent_type if parent_type is not None else reference_type + parent, allobjects = self._get_all_objects_by_type( - reference_type, refresh, running_config, name_only, reference_var + object_type, + refresh, + running_config, + name_only=name_only, + reference_var=reference_var, ) + if parent_type is not None: + for parent_obj in allobjects: + reference_type.refreshall(parent_obj) + + target_parent = next( + (p for p in allobjects if p.uid == parent_name), None + ) + if not target_parent: + target_parent = parent.find_or_create( + parent_name, parent_type + ) + + all_children = [] + for parent_obj in allobjects: + all_children.extend(parent_obj.findall(reference_type)) + + parent = target_parent + allobjects = all_children + return self._update_reference_in_objects( parent, allobjects, diff --git a/panos/network.py b/panos/network.py index b37bebab..b5a6055f 100644 --- a/panos/network.py +++ b/panos/network.py @@ -530,14 +530,12 @@ def set_logical_router( running_config=False, return_type="object", vrf_name="default", - **kwargs ): - """adds the given interface to the VRF by name. + """Set the logical router for this interface - This is more complicated than `set_virtual_router` as the logical routers have child VRF child elements, which - is where the interfaces are configured. - - This will use the VRF name 'default' by default. + Creates a reference to this interface in the specified logical router's + VRF and removes references to this interface from all other logical + routers. The logical router will be created if it doesn't exist. Args: lr_name (str): The name of the LogicalRouter or @@ -547,56 +545,27 @@ def set_logical_router( update (bool): Apply the changes to the device (Default: False) running_config: If refresh is True, refresh from the running configuration (Default: False) - vrf_name (str): Sets the vrf inside the LR. (Default: 'default') + vrf_name (str): The VRF name inside the LR. (Default: 'default') return_type (str): Specify what this function returns, can be either 'object' (the default) or 'bool'. If this is 'object', - then the return value is the LogicalRouter in question. If + then the return value is the Vrf in question. If this is 'bool', then the return value is a boolean that tells you about if the live device needs updates (update=False) or was updated (update=True). """ - - # First we get all the logical routers - parent, all_logical_routers = self._get_all_objects_by_type( - LogicalRouter, + return self._set_reference( + vrf_name, + Vrf, + "interface", + "list", + True, refresh, + update, running_config, - name_only=False, - reference_var="vrf", - ) - target_lr: LogicalRouter | None - - target_lr = next((lr for lr in all_logical_routers if lr.uid == lr_name), None) - if not target_lr: - # If the LR isn't found, create it instead - target_lr = LogicalRouter(name=lr_name) - parent.add(target_lr) - vrf = Vrf(name=vrf_name) - target_lr.add(vrf) - target_lr.create() - - # Remove interface from other LRs first - for lr in all_logical_routers: - Vrf.refreshall(lr) - if lr.name != lr_name: - for vrf in lr.findall(Vrf): - if vrf.interface: - if self.name in vrf.interface: - vrf.interface.remove(self.name) - if update: - vrf.update("interface") - - return self._update_reference_in_objects( - target_lr, - target_lr.findall(Vrf), - reference_name=vrf_name, - reference_var="interface", - reference_type=Vrf, - var_type="list", - return_type=return_type, - update=update, - exclusive=True, - **kwargs + return_type, + False, + parent_type=LogicalRouter, + parent_name=lr_name, ) def get_counters(self):