diff --git a/examples/01_icetray/05_convert_i3_files_advanced.py b/examples/01_icetray/05_convert_i3_files_advanced.py index 4d6f3b6f6..c83356466 100644 --- a/examples/01_icetray/05_convert_i3_files_advanced.py +++ b/examples/01_icetray/05_convert_i3_files_advanced.py @@ -91,7 +91,6 @@ def main( mctree="I3MCTree", mmctracklist="MMCTrackList", extractor_name=f"calorimetry_pad_{str(padding)}", - daughters=False, is_corsika=False, ) diff --git a/src/graphnet/data/extractors/icecube/i3calorimetry.py b/src/graphnet/data/extractors/icecube/i3calorimetry.py index 09f98e73b..90023e7f6 100644 --- a/src/graphnet/data/extractors/icecube/i3calorimetry.py +++ b/src/graphnet/data/extractors/icecube/i3calorimetry.py @@ -1,6 +1,6 @@ """Extract all the visible particles entering the volume.""" -from typing import Dict, Any, TYPE_CHECKING, Tuple, List +from typing import Dict, Any, TYPE_CHECKING, Tuple, Union, List from .utilities.gcd_hull import GCD_hull from .i3extractor import I3Extractor @@ -8,20 +8,40 @@ import numpy as np from graphnet.utilities.imports import has_icecube_package +from copy import deepcopy +from collections import deque if has_icecube_package() or TYPE_CHECKING: from icecube import ( icetray, dataclasses, MuonGun, + simclasses, ) # pyright: reportMissingImports=false + DARK = dataclasses.I3Particle.ParticleShape.Dark + class I3Calorimetry(I3Extractor): """Event level energy labeling for IceCube data. This class extracts cumulative energy information from all visible - particles entering the detector volume, during the event. + particles entering the detector volume, during the event. The recorded energy is split into a "target" and "background" contribution, where the target contribution consists of all particles that downstream of neutrino primaries (or only the highest energy neutrino primary if the corresponding flag is set) and the background contribution consists of all other particles. The recorded energy is further split into a "track" and "cascade" contribution, where the track contribution consists of all energy deposited by particles that are classified as tracks (i.e. if a track is recorded in the MMCTrackList) and the cascade contribution consists of all energy deposited by particles that are not classified as tracks. The recorded energy varies depending on whether or not the entrance_energy flag is set. If the entrance_energy flag is set, the energy recorded is the energy of all visible particles entering the volume as they enter the volume. If the entrance_energy flag is not set, then the recorded energy is only the energy deposited inside the volume i.e. if a muon enters the volume with 100 GeV and leaves with 80 GeV, then the track energy recorded would be 20 GeV. + + Returns a dictionary with the following keys + - e_track_target: Energy entering/deposited by target tracks. + - e_cascade_target: Energy entering/deposited by target cascades. + - e_target: Total energy entering/deposited by target particles. + - e_track_bkg: Energy entering/deposited by background tracks. + - e_cascade_bkg: Energy entering/deposited by background cascades. + - e_bkg: Total energy entering/deposited by background particles. + - fraction_target_total: Fraction of total recorded energy that is from target particles. + - fraction_target_primary: Fraction of primar(y/ies) energy that is recorded as entering/deposited by target particles. + - fraction_cascade_target: Fraction of target energy that is from cascades. + - e_track_total: Total energy entering/deposited by tracks. + - e_cascade_total: Total energy entering/deposited by cascades. + - e_total: Total energy entering/deposited by all particles. + - fraction_cascade_total: Fraction of total recorded energy that is from cascades. """ def __init__( @@ -30,9 +50,8 @@ def __init__( mctree: str = "I3MCTree", mmctracklist: str = "MMCTrackList", extractor_name: str = "I3Calorimetry", - daughters: bool = False, highest_energy_primary: bool = False, - cascade_deposited_only: bool = True, + entrance_energy: bool = False, **kwargs: Any, ) -> None: """Create a ConvexHull object from the GCD file. @@ -42,37 +61,16 @@ def __init__( mctree: Name of the I3MCTree in the frame. mmctracklist: Name of the MMCTrackList in the frame. extractor_name: Name of the extractor. - daughters: If True, only calculate energies for particles - that are daughters of the primary. - highest_energy_primary: If True, takes into account only the - primary with the highest energy. - NOTE: Only makes a difference if daughters is False - and the event is not a Corsika event. - cascade_deposited_only: If True, consider only energies from - cascades that are marked as visible. If False the total - energy of a cascade is counted. - - Variable explanation: - - e_entrance_track: Total energy of tracks entering the hull. - - e_deposited_track: Total energy deposited by tracks in the hull. - - e_cascade: Total energy of cascade particles in the hull. - - e_visible: Total energy of particles entering the hull. - NOTE: if daughters is True, this is the total visible energy - of daughter particles of the primary particles. If this is 0 - that means that all the light in the detector comes from - particles that are daughters of coincident primaries. - - fraction_primary: Fraction of `e_visible` compared to - the primary energy. - - fraction_cascade: Fraction of the total energy that is - deposited by cascade particles compared to the total energy. + highest_energy_primary: If True, only consider particles that are + daughters of the highest energy primary. + entrance_energy: If True, consider entrance energy. """ # Member variable(s) self.hull = hull self.mctree = mctree self.mmctracklist = mmctracklist - self.daughters = daughters self.highest_energy_primary = highest_energy_primary - self.cascade_deposited_only = cascade_deposited_only + self.entrance_energy = entrance_energy # Base class constructor super().__init__(extractor_name=extractor_name, **kwargs) @@ -80,259 +78,308 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]: """Extract all the visible particles entering the volume.""" output = {} if self.frame_contains_info(frame): + target_tree, bkg_tree = self.split_mc_tree( + frame, highest_energy_primary=self.highest_energy_primary + ) + if len(target_tree) + len(bkg_tree) != len(frame[self.mctree]): + raise ValueError( + f"Split mctree has different number of particles than original mctree\nOriginal mctree: {len(frame[self.mctree])}\nTarget tree: {len(target_tree)}\nBkg tree: {len(bkg_tree)}\nHighest energy primary flag: {self.highest_energy_primary}\nEvent header: {frame['I3EventHeader']}" + ) - primaries = self.get_primaries( - frame, - self.daughters, - self.highest_energy_primary, + # For the target we consider either all neutrino primary products or only the highest energy primary of the neutrino depending on the flag. + target_primaries = self.get_primaries( + target_tree, + daughters=True, + highest_energy_primary=self.highest_energy_primary, + ) + target_primaries = self.check_primary_energy( + target_tree, target_primaries ) + # For background we consider everything in the background tree. + bkg_primaries = self.get_primaries( + bkg_tree, daughters=False, highest_energy_primary=False + ) + bkg_primaries = self.check_primary_energy(bkg_tree, bkg_primaries) - if not len(primaries) == 0: - - MMCTrackList = frame[self.mmctracklist] - # Filter tracks that are not daughters of the desired - if self.daughters: - temp_MMCTrackList = [] - for track in MMCTrackList: - for p in primaries: - if frame[self.mctree].is_in_subtree( - p.id, track.GetI3Particle().id - ): - temp_MMCTrackList.append(track) - break - MMCTrackList = temp_MMCTrackList - - # Create a lookup dict for the tracks - track_lookup = {} - for track in MuonGun.Track.harvest( - frame[self.mctree], MMCTrackList - ): - track_lookup[track.id] = track + target_primaries_energy = sum([p.energy for p in target_primaries]) + bkg_primaries_energy = sum([p.energy for p in bkg_primaries]) - e_cascade, e_dep_track, e_ent_track = self.get_energies( - frame, primaries, track_lookup + if len(target_tree) > 0: + e_track_target = self.total_track_energy( + frame, + target_tree, + entrance_energy=self.entrance_energy, ) - - primary_energy = sum([p.energy for p in primaries]) else: - e_ent_track = np.nan - e_dep_track = np.nan - e_cascade = np.nan - primary_energy = np.nan + e_track_target = 0.0 + # Sanity check ensuring no double counting - e_total = e_ent_track + e_cascade + if not (e_track_target <= target_primaries_energy * (1 + 1e-6)): + raise ValueError( + f"Energy deposited in target is greater than primary energy: {e_track_target} > {target_primaries_energy}\nEvent header: {frame['I3EventHeader']}" + ) + if len(bkg_tree) > 0: + e_track_bkg = self.total_track_energy( + frame, + bkg_tree, + entrance_energy=self.entrance_energy, + ) + else: + e_track_bkg = 0.0 - # In case all particles are considered and - # there is no energy deposited in the hull, - # we warn the user. - if all( - ( - not self.daughters, - not self.highest_energy_primary, - e_total == 0, + if len(target_tree) > 0: + e_cascade_target = self.total_cascade_energy( + target_tree, target_primaries ) - ): + else: + e_cascade_target = 0.0 + if len(bkg_tree) > 0: + e_cascade_bkg = self.total_cascade_energy( + bkg_tree, bkg_primaries + ) + else: + e_cascade_bkg = 0.0 + + e_total_target = e_track_target + e_cascade_target + e_total_bkg = e_track_bkg + e_cascade_bkg + + e_total = e_total_target + e_total_bkg + + if e_total == 0.0: self.warning( "No energy deposited in the hull, " "Think about increasing the padding." f"\nCurrent padding: {self.hull.padding}" - f"\nTotal energy: {e_total}" - f"\nTrack energy: {e_ent_track}" - f"\nCascade energy: {e_cascade}" f"\nEvent header: {frame['I3EventHeader']}" ) - # Check only in the case that there were primaries - if not len(primaries) == 0 and (not np.isnan(e_total)): - - # total energy should always be less than the primary energy - assert e_total <= ( - primary_energy * (1 + 1e-6) - ), "Total energy on entrance is greater than primary energy\ - \nTotal energy: {}\ - \nPrimary energy: {}\ - \nTrack energy: {}\ - \nCascade energy: {}\ - {}".format( - e_total, - primary_energy, - e_ent_track, - e_cascade, - frame["I3EventHeader"], + if not ( + e_total_target <= (target_primaries_energy * (1 + 1e-6)) + or (e_total_target - target_primaries_energy < 0.5) + ): + raise ValueError( + "Total energy is greater than primary energy\n" + f"Total energy: {e_total_target}\n" + f"Primary energy: {target_primaries_energy}\n" + f"Track deposited energy: {e_track_target}\n" + f"Cascade deposited energy: {e_cascade_target}\n" + f"{frame['I3EventHeader']}" ) - assert ( - primary_energy > 0 - ), "Primary energy is 0, this should not happen.\ - \nTotal energy: {}\ - \nTrack energy: {}\ - \nCascade energy: {}\ - {}".format( - e_total, - e_ent_track, - e_cascade, - frame["I3EventHeader"], + if not ( + e_total_bkg <= (bkg_primaries_energy * (1 + 1e-6)) + or (e_total_bkg - bkg_primaries_energy < 0.5) + ): + raise ValueError( + "Total background energy is greater than primary energy\n" + f"Total background energy: {e_total_bkg}\n" + f"Background primary energy: {bkg_primaries_energy}\n" + f"Track deposited background energy: {e_track_bkg}\n" + f"Cascade deposited background energy: {e_cascade_bkg}\n" + f"{frame['I3EventHeader']}" ) - fraction_primary = e_total / primary_energy - cascade_fraction = None + fraction_target_total = ( + e_total_target / e_total if e_total > 0 else 0.0 + ) + target_cascade_fraction = ( + e_cascade_target / e_total_target + if e_total_target > 0 + else 0.0 + ) + if e_total > 0: - cascade_fraction = e_cascade / e_total + cascade_fraction_tot = ( + e_cascade_target + e_cascade_bkg + ) / e_total + if target_primaries_energy > 0: + fraction_primary = e_total_target / target_primaries_energy + else: + fraction_primary = None output.update( { - "e_entrance_track_" + self._extractor_name: e_ent_track, - "e_deposited_track_" + self._extractor_name: e_dep_track, - "e_cascade_" + self._extractor_name: e_cascade, - "e_visible_" + self._extractor_name: e_total, - "fraction_primary_" + "e_track_target_" + self._extractor_name: e_track_target, + "e_cascade_target_" + + self._extractor_name: e_cascade_target, + "e_target_" + self._extractor_name: e_total_target, + "e_track_bkg_" + self._extractor_name: e_track_bkg, + "e_cascade_bkg_" + self._extractor_name: e_cascade_bkg, + "e_bkg_" + self._extractor_name: e_total_bkg, + "fraction_target_total_" + + self._extractor_name: fraction_target_total, + "fraction_target_primary_" + self._extractor_name: fraction_primary, - "fraction_cascade_" - + self._extractor_name: cascade_fraction, + "fraction_cascade_target_" + + self._extractor_name: target_cascade_fraction, + "e_track_total_" + + self._extractor_name: e_track_target + + e_track_bkg, + "e_cascade_total_" + + self._extractor_name: e_cascade_target + + e_cascade_bkg, + "e_total_" + self._extractor_name: e_total, + "fraction_cascade_total_" + + self._extractor_name: ( + cascade_fraction_tot if e_total > 0 else 0.0 + ), } ) output = {k: v for k, v in output.items() if k not in self._exclude} return output - def get_energies( + def frame_contains_info(self, frame: "icetray.I3Frame") -> bool: + """Check if the frame contains the necessary information.""" + return self.mctree in frame and self.mmctracklist in frame + + def total_track_energy( self, frame: "icetray.I3Frame", - particles: List["dataclasses.I3Particle"], - track_lookup: Dict["icetray.I3ParticleID", "icetray.I3Particle"], - ) -> Tuple[float, float, float]: - """Get the total energy of cascade particles on entrance.""" - e_cascade = 0 - e_dep_track = 0 - e_ent_track = 0 + mctree: "dataclasses.I3MCTree", + entrance_energy: bool = False, + ) -> float: + """Get the total energy deposited by tracks entering the volume. - if len(particles) == 0: - return e_cascade, e_dep_track, e_ent_track - - for particle in particles: - length = particle.length - if length != length: - length = 0 - # If the particle is a track in the MMCTrackList take the - # energy at the entrance and exit of the hull. - # NOTE: We do not consider daughters of tracks, - # because they are already included in the track energy. - if particle.is_track & (particle.id in track_lookup): - track = track_lookup[particle.id] - - # Find distance to entrance and exit from sampling volume - intersections = self.hull.surface.intersection( - track.pos, track.dir + If entrance_energy is True, return the total energy entering the + volume as tracks instead of the energy deposited. + """ + energy = 0 + + if self._is_corsika: + mmc_track_list = frame[self.mmctracklist] + else: + mmc_track_list = self.filter_track_list( + mctree, frame[self.mmctracklist] + ) + + track_list = deque(MuonGun.Track.harvest(mctree, mmc_track_list)) + + while len(track_list) > 0: + track = track_list.popleft() + + try: + particle = mctree.get_particle(track.id) + except RuntimeError: + # If the particle does not exist in the mctree, that means a particle further up was processed and therefore it should not be counted + continue + + # Find distance to entrance and exit from sampling volume + intersections = self.hull.surface.intersection( + track.pos, track.dir + ) + + # Check if the track actually enters the volume. Values are NAN if the ray does not intersect the hull, negative if the intersection is behind the "origin". Uncertain if we can have negative intersections for both first and second intersection without it being converted to NAN (no intersection) but we check for both to be sure. + if not ( + ( + np.isfinite(intersections.first) + and (intersections.first < particle.length) ) - # Get the corresponding energies - try: - e0 = track.get_energy(intersections.first) - e1 = track.get_energy(intersections.second) - - # Catch MuonGun errors - except RuntimeError as e: - if ( - "sum of losses is smaller than " - "energy at last checkpoint" in str(e) - ): - hdr = frame["I3EventHeader"] - e.add_note(f"Error in MuonGun track in event {hdr}") - self.warning(f"Skipping bad event {hdr}: {e}") - e0 = np.nan - e1 = np.nan - e_cascade = np.nan - continue # skip this frame - else: - raise # re-raise unexpected errors - - e_dep_track += e0 - e1 - e_ent_track += e0 - # if the particle is not in the hull, but has daughters, - # we add the energies of the daughters. - elif not self.is_in_hull(particle): - daughters = dataclasses.I3MCTree.get_daughters( - frame[self.mctree], particle + and ( + np.isfinite(intersections.second) + and intersections.second > 0 ) - if len(daughters) == 0: - continue - ( - e_cascade, - e_dep_track, - e_ent_track, - ) = tuple( - np.add( - (e_cascade, e_dep_track, e_ent_track), - self.get_energies( - frame, - daughters, - track_lookup, - ), + ): + continue + + # Get the corresponding energies + try: + e0 = track.get_energy(intersections.first) + e1 = track.get_energy(intersections.second) + + except RuntimeError as e: + if ( + "sum of losses is smaller than " + "energy at last checkpoint" in str(e) + ): + hdr = frame["I3EventHeader"] + e.add_note(f"Error in MuonGun track in event {hdr}") + self.warning( + f"Skipping bad track {hdr}: {e}" + f"\nTotal energy of offending particle: {particle.energy}" ) - ) - # If the particle is a cascade in the hull, we add its energy. - elif particle.is_cascade: - if self.cascade_deposited_only: - # Check wether the cascade is made up of smaller segments - # in this case the shape is dark and we want to count - # the energy of its daughters. - if ( - particle.shape - != dataclasses.I3Particle.ParticleShape.Dark - ): - e_cascade += particle.energy - else: - ( - e_cascade, - e_dep_track, - e_ent_track, - ) = tuple( - np.add( - (e_cascade, e_dep_track, e_ent_track), - self.get_energies( - frame, - dataclasses.I3MCTree.get_daughters( - frame[self.mctree], particle - ), - track_lookup, - ), - ) - ) + continue else: - # In this case we consider the total cascade - # energy and therefore do not look at the daughters - e_cascade += particle.energy - # The particle is in the hull and not a track in the MMCTrackList, - # or a cascade, so we look at its daughters. - # Could be a NuMu interacting within the hull. - else: - ( - e_cascade, - e_dep_track, - e_ent_track, - ) = tuple( - np.add( - (e_cascade, e_dep_track, e_ent_track), - self.get_energies( - frame, - dataclasses.I3MCTree.get_daughters( - frame[self.mctree], particle - ), - track_lookup, - ), - ) - ) + raise - return e_cascade, e_dep_track, e_ent_track + # Accumulate + if entrance_energy: + energy += e0 + # if we are looking at the entrance energy then all energy entering the volume as a track is considered "track energy" even if it is later deposited in a cascade, so we remove all descendants of the track from the mctree to avoid double counting + mctree.erase(track.id) + else: + energy += e0 - e1 + # if we are looking at the deposited energy then we only want to remove the tracks that have either deposited all their energy in the volume or left the volume again thus descendants cannot produce cascades in the volume. + if (e1 == 0) or (intersections.second < particle.length): + mctree.erase(track.id) + return energy - def frame_contains_info(self, frame: "icetray.I3Frame") -> bool: - """Check if the frame contains the necessary information.""" - return self.mctree in frame and self.mmctracklist in frame + def total_cascade_energy( + self, + mctree: "dataclasses.I3MCTree", + primaries: "dataclasses.ListI3Particle", + ) -> float: + """Get the total energy of cascade particles on entrance.""" + particles = deque(primaries) - def is_in_hull(self, particle: "dataclasses.I3Particle") -> bool: - """Check if a particle is in the hull.""" - pos = np.array(particle.pos) - direc = np.array([particle.dir.x, particle.dir.y, particle.dir.z]) - length = particle.length if particle.length is not None else 0 - pos = pos + direc * length + if len(particles) == 0: + return 0.0 + + pos_list, direc_list, length_list, cascade_bool, energies = ( + [], + [], + [], + [], + [], + ) + + while len(particles) > 0: + p = particles.popleft() + p_children = mctree.get_daughters(p) + if len(p_children) > 0: + particles.extend(p_children) + continue + if p.is_track or p.shape == DARK: + continue + + pos_list.append([p.pos.x, p.pos.y, p.pos.z]) + direc_list.append([p.dir.x, p.dir.y, p.dir.z]) + length_list.append(p.length) + cascade_bool.append(p.is_cascade) + energies.append(p.energy) + + if len(energies) == 0: + return 0.0 + + length = np.array(length_list).astype(float) + length[np.isnan(length)] = 0 + pos = np.asarray(pos_list) + direc = np.asarray(direc_list) + cascade_bool = np.array(cascade_bool) + energies = np.array(energies) + pos = (pos.T + direc.T * length).T + in_hull = self.hull.point_in_hull(pos) + + return np.sum(energies[cascade_bool & in_hull]) + + def filter_track_list( + self, + mctree: "dataclasses.I3MCTree", + track_list: "simclasses.I3MMCTrackList", + ) -> "simclasses.I3MMCTrackList": + """Filter the track list based on the mctree provided. - return self.hull.point_in_hull(pos) + (This function is only meant to run on target/bkg split trees) + """ + filtered_track_list = [] + for track in track_list: + try: + mctree.get_particle(track.particle.id) + filtered_track_list.append(track) + except RuntimeError as e: + if "particleID not found" in str(e): + # if particle is not found in the mctree then it should not be included as it in the other tree. + continue + else: + raise e + return simclasses.I3MMCTrackList(filtered_track_list) diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py index 9b8237fe5..1573d6e33 100644 --- a/src/graphnet/data/extractors/icecube/i3extractor.py +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -14,6 +14,8 @@ dataclasses, ) # pyright: reportMissingImports=false +from copy import deepcopy + class I3Extractor(Extractor): """Base class for extracting information from physics I3-frames. @@ -111,7 +113,7 @@ def __call__(self, frame: "icetray.I3Frame") -> dict: def check_primary_energy( self, - frame: "icetray.I3Frame", + mctree: "dataclasses.I3MCTree", primaries: Union[ "dataclasses.ListI3Particle", "dataclasses.I3Particle" ], @@ -122,7 +124,7 @@ def check_primary_energy( primary particle(s) are returned instead. Args: - frame: I3Frame object. + mctree: I3MCTree object. primaries: Primary particle or a list of primary particles. """ assert hasattr( @@ -132,7 +134,7 @@ def check_primary_energy( if isinstance(primaries, dataclasses.ListI3Particle): new_primaries = dataclasses.ListI3Particle() for primary in primaries: - primary = self.check_primary_energy(frame, primary) + primary = self.check_primary_energy(mctree, primary) if isinstance(primary, dataclasses.ListI3Particle): new_primaries.extend(primary) elif isinstance(primary, dataclasses.I3Particle): @@ -151,9 +153,7 @@ def check_primary_energy( if primary.energy != primary.energy: self.warning_once("Primary energy is nan checking daughters") - daughters = dataclasses.I3MCTree.get_daughters( - frame[self.mctree], primary - ) + daughters = dataclasses.I3MCTree.get_daughters(mctree, primary) if len(daughters) == 0: raise ValueError( "Primary energy is nan and no daughters found" @@ -165,7 +165,7 @@ def check_primary_energy( def get_primaries( self, - frame: "icetray.I3Frame", + mctree: "dataclasses.I3MCTree", daughters: bool = False, highest_energy_primary: bool = True, ) -> "dataclasses.ListI3Particle": @@ -183,7 +183,7 @@ def get_primaries( Corsika case. Input: - frame: I3Frame object + mctree: I3MCTree object daughters: If True only daughters of the primary neutrino are returned highest_energy_primary: If True, return the primary with the highest energy. If False, return all primaries. @@ -195,7 +195,7 @@ def get_primaries( ), "mctree should be instantiated by subclass" if not self._is_corsika: - primaries = frame[self.mctree].get_primaries() + primaries = mctree.get_primaries() if daughters: primaries = [ p @@ -217,16 +217,13 @@ def get_primaries( # get the original primary neutrino(s) primary_nus = [ - p - for p in frame[self.mctree].get_primaries() - if p.is_neutrino + p for p in mctree.get_primaries() if p.is_neutrino ] # recursively search for in-ice neutrino daughters primaries = self.find_in_ice_daughters( - frame, + mctree, primary_nus, - self.mctree, ) # This is not expected to happen @@ -246,14 +243,13 @@ def get_primaries( primaries = dataclasses.ListI3Particle(primaries) if self._is_corsika: - primaries = frame[self.mctree].get_primaries() + primaries = mctree.get_primaries() return primaries def find_in_ice_daughters( self, - frame: "icetray.I3Frame", + mctree: "dataclasses.I3MCTree", particles: "dataclasses.ListI3Particle", - mctree: str, ) -> "dataclasses.ListI3Particle": """Find in-ice particles in the frame.""" if particles == []: @@ -268,9 +264,71 @@ def find_in_ice_daughters( else: ret.extend( self.find_in_ice_daughters( - frame, - frame[mctree].get_daughters(p.id), - mctree=mctree, + mctree, + mctree.get_daughters(p.id), ) ) return ret + + def split_mc_tree( + self, frame: "icetray.I3Frame", highest_energy_primary: bool = True + ) -> "dataclasses.I3MCTree": + """Split the mctree in subtrees corresponding to each primary particle. + + Into a subtree containing only the daughters of the primary + particle, and a subtree containing the rest of the particles in + the event. + """ + assert hasattr( + self, "mctree" + ), "mctree should be instantiated by subclass" + + main_tree = deepcopy(frame[self.mctree]) + bkg_tree = deepcopy(frame[self.mctree]) + + if self._is_corsika: + # create empty main tree and return bkg tree. + return dataclasses.I3MCTree(), bkg_tree + + all_primaries = main_tree.get_primaries() + if highest_energy_primary: + # grab the id of the highest energy primary + in_ice_daughters = self.find_in_ice_daughters( + main_tree, [p for p in all_primaries if p.is_neutrino] + ) + energies = np.array([p.energy for p in in_ice_daughters]) + p_highest = np.array(in_ice_daughters)[np.argmax(energies)] + parent_ids = [ + p.id for p in self.get_all_parents(main_tree, p_highest) + ] + parent_ids.append(p_highest.id) + + for primary in all_primaries: + if primary.is_neutrino: + if highest_energy_primary: + if primary.id not in parent_ids: + main_tree.erase(primary.id) + else: + bkg_tree.erase(primary.id) + else: + bkg_tree.erase(primary.id) + else: + main_tree.erase(primary.id) + return main_tree, bkg_tree + + def get_all_parents( + self, + mctree: "dataclasses.I3MCTree", + particle: "dataclasses.I3Particle", + ) -> list: + """Get all parents of a particle.""" + assert hasattr( + self, "mctree" + ), "mctree should be instantiated by subclass" + + parents = [] + while mctree.has_parent(particle.id): + parent = mctree.parent(particle.id) + parents.append(parent) + particle = parent + return parents diff --git a/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py b/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py index 19228d21c..b303ae17a 100644 --- a/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py +++ b/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py @@ -66,7 +66,9 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]: HEParticle.energy = 0 primary_energy = sum( prim.energy - for prim in self.get_primaries(frame, self.daughters) + for prim in self.get_primaries( + frame[self.mctree], self.daughters + ) ) distance = -1.0 EonEntrance = 0.0 @@ -178,8 +180,10 @@ def get_tracks( Args: frame: I3Frame object """ - primaries = self.get_primaries(frame, self.daughters) - primaries = [self.check_primary_energy(frame, p) for p in primaries] + primaries = self.get_primaries(frame[self.mctree], self.daughters) + primaries = [ + self.check_primary_energy(frame[self.mctree], p) for p in primaries + ] MMCTrackList = frame[self.mmctracklist] if self.daughters: @@ -419,9 +423,10 @@ def highest_energy_starting( # noqa: C901 containment = GN_containment_types.no_intersect.value visible_length = 0.0 if self.daughters: - primaries = self.get_primaries(frame, self.daughters) + primaries = self.get_primaries(frame[self.mctree], self.daughters) primaries = [ - self.check_primary_energy(frame, p) for p in primaries + self.check_primary_energy(frame[self.mctree], p) + for p in primaries ] particles = self.get_descendants(frame, primaries)