From dd01b6b97f5a6006faa05697cd2c511d7a51a492 Mon Sep 17 00:00:00 2001 From: BrunoSanchez Date: Thu, 28 May 2026 09:01:11 -0700 Subject: [PATCH] Fix matching to account for template fakes and mutual match strategy --- .../pipe/tasks/matchDiffimSourceInjected.py | 169 ++++++++++++++---- tests/test_matchSourceInjected.py | 3 +- 2 files changed, 134 insertions(+), 38 deletions(-) diff --git a/python/lsst/pipe/tasks/matchDiffimSourceInjected.py b/python/lsst/pipe/tasks/matchDiffimSourceInjected.py index d05f606e9..6061d6cc2 100644 --- a/python/lsst/pipe/tasks/matchDiffimSourceInjected.py +++ b/python/lsst/pipe/tasks/matchDiffimSourceInjected.py @@ -25,6 +25,7 @@ "MatchInjectedToAssocDiaSourceConfig"] import astropy.units as u +from astropy.table import Table, join, vstack import numpy as np from scipy.spatial import cKDTree @@ -43,9 +44,9 @@ class MatchInjectedToDiaSourceConnections( dimensions=("instrument", "visit", "detector")): - injectedCat = connTypes.Input( + injectionCat = connTypes.Input( doc="Catalog of sources injected in the images.", - name="{fakesType}_pvi_catalog", + name="{fakesType}VisitDetectorFakeSourceCat", storageClass="ArrowAstropy", dimensions=("instrument", "visit", "detector"), ) @@ -66,7 +67,7 @@ class MatchInjectedToDiaSourceConnections( "diaSrc. The schema is the union of the schemas for " "``fakeCat`` and ``diaSrc``.", name="{fakesType}{coaddName}Diff_matchDiaSrc", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("instrument", "visit", "detector"), ) @@ -111,12 +112,13 @@ class MatchInjectedToDiaSourceTask(PipelineTask): _DefaultName = "matchInjectedToDiaSource" ConfigClass = MatchInjectedToDiaSourceConfig - def run(self, injectedCat, diffIm, diaSources): + # def run(self, injectedCat, injectedTemplateCat, diffIm, diaSources): + def run(self, injectionCat, diffIm, diaSources): """Match injected sources to detected diaSources within a difference image bound. Parameters ---------- - injectedCat : `astropy.table.table.Table` + injectionCat : `astropy.table.table.Table` Table of catalog of synthetic sources to match to detected diaSources. diffIm : `lsst.afw.image.Exposure` Difference image where ``diaSources`` were detected. @@ -128,24 +130,32 @@ def run(self, injectedCat, diffIm, diaSources): Results struct with components. - ``matchedDiaSources`` : Fakes matched to input diaSources. Has - length of ``injectedCalexpCat``. (`pandas.DataFrame`) + length of ``injectionCat``. (`astropy.table.Table`) """ if self.config.doMatchVisit: - fakeCat = self._trimFakeCat(injectedCat, diffIm) + fakeCat = self._trimFakeCat(injectionCat, diffIm) else: - fakeCat = injectedCat + fakeCat = injectionCat if self.config.doForcedMeasurement: self._estimateFakesSNR(fakeCat, diffIm) - return self._processFakes(fakeCat, diaSources) + # Split the fake catalog into the initial injections and the variable sources themselves, + # which are generated as duplicates of the initial injections with a twin_id column. + # We then match only the initial injections to the diaSources, + # and then add back in the variable sources by matching them to their twins + initialFakeCat, variableDoublesFakeCat = self._splitVariables(fakeCat) + matchedFakes = self._processFakes(initialFakeCat, diaSources) + fullMatchedFakes = self._add_variables_to_matched(matchedFakes, variableDoublesFakeCat) - def _estimateFakesSNR(self, injectedCat, diffIm): + return Struct(matchDiaSources=fullMatchedFakes) + + def _estimateFakesSNR(self, injectionCat, diffIm): """Estimate the signal-to-noise ratio of the fakes in the given catalog. Parameters ---------- - injectedCat : `astropy.table.Table` + injectionCat : `astropy.table.Table` Catalog of synthetic sources to estimate the S/N of. **This table will be modified in place**. diffIm : `lsst.afw.image.Exposure` @@ -176,8 +186,8 @@ def _estimateFakesSNR(self, injectedCat, diffIm): # Create an afw table from the input catalog outputCatalog = afwTable.SourceCatalog(schema) - outputCatalog.reserve(len(injectedCat)) - for row in injectedCat: + outputCatalog.reserve(len(injectionCat)) + for row in injectionCat: outputRecord = outputCatalog.addNew() outputRecord.setId(row['injection_id']) outputRecord.setCoord(lsstGeom.SpherePoint(row["ra"], row["dec"], lsstGeom.degrees)) @@ -205,17 +215,18 @@ def _estimateFakesSNR(self, injectedCat, diffIm): # Add the forced measurement columns to the input catalog for column in forcedSources_table.columns: if "Flux" in column or "flag" in column: - injectedCat["forced_"+column] = forcedSources_table[column] + injectionCat["forced_"+column] = forcedSources_table[column] # Add the SNR columns to the input catalog - for column in injectedCat.colnames: + for column in injectionCat.colnames: if column.endswith("instFlux"): - flux = injectedCat[column] - fluxErr = injectedCat[column+"Err"].copy() + # flux = injectionCat[column] + flux = np.abs(injectionCat[column]) + fluxErr = injectionCat[column+"Err"].copy() fluxErr = np.where( (fluxErr <= 0) | (np.isnan(fluxErr)), np.nanmax(fluxErr), fluxErr) - injectedCat[column+"_SNR"] = flux / fluxErr + injectionCat[column+"_SNR"] = flux / fluxErr def _processFakes(self, injectedCat, diaSources): """Match fakes to detected diaSources within a difference image bound. @@ -238,12 +249,12 @@ def _processFakes(self, injectedCat, diaSources): length of ``fakeCat``. (`pandas.DataFrame`) """ # First match the diaSrc to the injected fakes - injectedCat = injectedCat.to_pandas() + # injectedCat = injectedCat.to_pandas() nPossibleFakes = len(injectedCat) fakeVects = self._getVectors( - np.radians(injectedCat.ra), - np.radians(injectedCat.dec)) + np.radians(injectedCat['ra']), + np.radians(injectedCat['dec'])) diaSrcVects = self._getVectors( diaSources['coord_ra'], diaSources['coord_dec']) @@ -252,16 +263,28 @@ def _processFakes(self, injectedCat, diaSources): dist, idxs = diaSrcTree.query( fakeVects, distance_upper_bound=np.radians(self.config.matchDistanceArcseconds / 3600)) - nFakesFound = np.isfinite(dist).sum() + # handshake matching, that is symmetrize the match by matching the + # diaSrcs back to the fakes and only keeping those matches where the + # same pair is returned + diaSrcTreeBack = cKDTree(fakeVects) + distBack, idxsBack = diaSrcTreeBack.query( + diaSrcVects, + distance_upper_bound=np.radians(self.config.matchDistanceArcseconds / 3600)) + + idxsAux = np.where(np.array(idxs) < len(diaSources), idxs, -1) + idxsBackMatched = np.where(idxsAux >= 0, idxsBack[idxsAux], -1) + idxsMatched = np.where(idxsBackMatched == np.arange(len(injectedCat)), idxs, -1) + distMatched = np.where(idxsBackMatched == np.arange(len(injectedCat)), dist, np.inf) + nFakesFound = np.isfinite(distMatched).sum() self.log.info("Found %d out of %d possible in diaSources.", nFakesFound, nPossibleFakes) # assign diaSourceId to the matched fakes - diaSrcIds = diaSources['id'][np.where(np.isfinite(dist), idxs, 0)] - matchedFakes = injectedCat.assign(diaSourceId=np.where(np.isfinite(dist), diaSrcIds, 0)) - matchedFakes['dist_diaSrc'] = np.where(np.isfinite(dist), 3600*np.rad2deg(dist), -1) - - return Struct(matchDiaSources=matchedFakes) + diaSrcIds = diaSources['id'][np.where(np.isfinite(distMatched), idxsMatched, 0)] + matchedFakes = injectedCat.copy() + matchedFakes['diaSourceId'] = np.where(np.isfinite(distMatched), diaSrcIds, 0) + matchedFakes['dist_diaSrc'] = np.where(np.isfinite(distMatched), 3600*np.rad2deg(distMatched), -1) + return matchedFakes def _getVectors(self, ras, decs): """Convert ra dec to unit vectors on the sphere. @@ -350,6 +373,70 @@ def _trimFakeCat(self, fakeCat, image): return fakeCat[isContainedRaDec & isContainedXy] + def _splitVariables(self, fakeCat): + """Split out the duplicated injections, that are used to generate + variable sources in the fake catalog. + + Parameters + ---------- + fakeCat : `astropy.table.table.Table` + The catalog of fake sources that was input + + Returns + ------- + initialFakeCat : `astropy.table.table.Table` + Subset of the input catalog corresponding to initial sources. + variableDoublesFakeCat : `astropy.table.table.Table` + Subset of the input catalog corresponding to variable sources. + """ + if "twin_id" not in fakeCat.colnames: + self.log.warning("No twin_id column found in fake catalog.") + return fakeCat, None + + isVariable = fakeCat["twin_id"] > 0 + + return fakeCat[~isVariable], fakeCat[isVariable] + + def _add_variables_to_matched(self, matchedFakes, variableDoublesFakeCat): + """Add variable sources back into the matched fakes catalog. + + Parameters + ---------- + matchedFakes : `astropy.table.table.Table` + Catalog of matched fakes to diaSources, corresponding to the static + sources in the input fake catalog. + variableDoublesFakeCat : `astropy.table.table.Table` + Catalog of variable sources in the input fake catalog. + + Returns + ------- + fullMatchedFakes : `astropy.table.table.Table` + Catalog of matched fakes to diaSources, corresponding to both the + static and variable sources in the input fake catalog. + """ + if variableDoublesFakeCat is None: + return matchedFakes + + # For the variable sources, we have a match to diaSources if their twins + # had a match, so we fill the diaSourceId with the diaSourceId of the matched + # twin if it exists and 0 otherwise, and we set the distance to -1 to + # indicate that these are variable sources that were not directly matched + # to diaSources. + variableDoublesFakeCat = variableDoublesFakeCat.copy() + variableDoublesFakeCat['diaSourceId'] = 0 + variableDoublesFakeCat['dist_diaSrc'] = -1 + + # Match variable sources to their twin's matched diaSource + for i, row in enumerate(variableDoublesFakeCat): + twin_id = row['twin_id'] + # Find matching twin in matchedFakes + twin_matches = matchedFakes[matchedFakes['injection_id'] == twin_id] + if len(twin_matches) > 0: + variableDoublesFakeCat['diaSourceId'][i] = twin_matches['diaSourceId'][0] + variableDoublesFakeCat['dist_diaSrc'][i] = twin_matches['dist_diaSrc'][0] + + return vstack([matchedFakes, variableDoublesFakeCat], metadata_conflicts='silent') + class MatchInjectedToAssocDiaSourceConnections( PipelineTaskConnections, @@ -371,7 +458,7 @@ class MatchInjectedToAssocDiaSourceConnections( "diaSrc. The schema is the union of the schemas for " "``fakeCat`` and ``diaSrc``.", name="{fakesType}{coaddName}Diff_matchDiaSrc", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("instrument", "visit", "detector"), ) matchAssocDiaSources = connTypes.Output( @@ -379,7 +466,7 @@ class MatchInjectedToAssocDiaSourceConnections( "associatedDiaSources. The schema is the union of the schemas for " "``fakeCat`` and ``associatedDiaSources``.", name="{fakesType}{coaddName}Diff_matchAssocDiaSrc", - storageClass="DataFrame", + storageClass="ArrowAstropy", dimensions=("instrument", "visit", "detector"), ) @@ -415,16 +502,26 @@ def run(self, assocDiaSources, matchDiaSources): """ # Match the fakes to the associated sources. For this we don't use the coordinates # but instead check for the diaSources. Since they were present in the table already + if not isinstance(assocDiaSources, Table): + assocDiaSources = Table.from_pandas(assocDiaSources, index=False) + + matchDiaSources["diaSourceId"] = np.asarray(matchDiaSources["diaSourceId"], dtype=np.int64) + assocDiaSources["diaSourceId"] = np.asarray(assocDiaSources["diaSourceId"], dtype=np.int64) + nPossibleFakes = len(matchDiaSources) - matchDiaSources['isAssocDiaSource'] = matchDiaSources.diaSourceId.isin(assocDiaSources.diaSourceId) - assocNFakesFound = matchDiaSources.isAssocDiaSource.sum() + matchDiaSources["isAssocDiaSource"] = np.isin( + matchDiaSources["diaSourceId"], assocDiaSources["diaSourceId"] + ) + assocNFakesFound = matchDiaSources['isAssocDiaSource'].sum() self.log.info("Found %d out of %d possible in assocDiaSources."%(assocNFakesFound, nPossibleFakes)) return Struct( - matchAssocDiaSources=matchDiaSources.merge( - assocDiaSources.reset_index(drop=True), - on="diaSourceId", - how="left", - suffixes=('_ssi', '_diaSrc') + matchAssocDiaSources=join( + matchDiaSources, + assocDiaSources, + keys="diaSourceId", + join_type="left", + table_names=("ssi", "diaSrc"), + uniq_col_name="{col_name}_{table_name}", ) ) diff --git a/tests/test_matchSourceInjected.py b/tests/test_matchSourceInjected.py index 9d2ee58f6..7f7105a1f 100644 --- a/tests/test_matchSourceInjected.py +++ b/tests/test_matchSourceInjected.py @@ -22,7 +22,6 @@ # import numpy as np -import pandas as pd import unittest from astropy.table import Table @@ -106,7 +105,7 @@ def setUp(self): ) # only 4 injected sources are associated - self.assocDiaSources = pd.DataFrame( + self.assocDiaSources = Table( { "diaSourceId": [101, 102, 103, 201, 202, 205, 207], "band": np.repeat("r", 7),