Skip to content

Modules

A collection of utility functions for data manipulation.

This module contains a collection of utility functions for astronomical data manipulation.

catalog_quest(name, service='Vizier')

Fetch a catalog from a given astronomical service e.g. VizieR, Simbad.

:param name: The name of the catalog to be fetched. :type name: str

:param service: The name of the astronomical service to be used. :type service: str

:return: A pandas DataFrame containing the fetched catalog. :rtype: pd.DataFrame

:raises _UnsupportedServiceError: If an unsupported service is provided.

Source code in rgc/utils/data.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def catalog_quest(name: str, service: str = "Vizier") -> pd.DataFrame:
    """
    Fetch a catalog from a given astronomical service e.g. VizieR, Simbad.

    :param name: The name of the catalog to be fetched.
    :type name: str

    :param service: The name of the astronomical service to be used.
    :type service: str

    :return: A pandas DataFrame containing the fetched catalog.
    :rtype: pd.DataFrame

    :raises _UnsupportedServiceError: If an unsupported service is provided.
    """
    if service == "Vizier":
        Vizier.ROW_LIMIT = -1
        catalog = Vizier.get_catalogs(name)
        return cast(pd.DataFrame, catalog[0].to_pandas())
    else:
        raise _UnsupportedServiceError()

celestial_capture(survey, ra, dec, filename)

Capture a celestial image using the SkyView service.

:param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'. :type survey: str

:param ra: The right ascension of the celestial object. :type ra: Skycoord

:param dec: The declination of the celestial object. :type dec: Skycoord

:param filename: The name of the file to save the image. :type filename: str

Source code in rgc/utils/data.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def celestial_capture(survey: str, ra: float, dec: float, filename: str) -> None:
    """
    Capture a celestial image using the SkyView service.

    :param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'.
    :type survey: str

    :param ra: The right ascension of the celestial object.
    :type ra: Skycoord

    :param dec: The declination of the celestial object.
    :type dec: Skycoord

    :param filename: The name of the file to save the image.
    :type filename: str
    """
    image = SkyView.get_images(position=f"{ra}, {dec}", survey=survey, coordinates="J2000", pixels=(150, 150))[0]

    comment = str(image[0].header["COMMENT"])
    comment = comment.replace("\n", " ")
    comment = comment.replace("\t", " ")

    image[0].header.remove("comment", comment, True)
    image[0].header.add_comment(comment)

    folder_path = Path(filename).parent
    Path(folder_path).mkdir(parents=True, exist_ok=True)
    image.writeto(filename, overwrite=True)

celestial_capture_bulk(catalog, survey, img_dir, classes=None, cls_col=None)

Capture celestial images for a catalog of celestial objects.

:param catalog: A pandas DataFrame containing the catalog of celestial objects. :type catalog: pd.DataFrame

:param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'. :type survey: str

:param img_dir: The path to the directory to save the images. :type img_dir: str

:param classes: A dictionary containing the classes of the celestial objects. :type classes: dict

:param cls_col: The name of the column containing the class labels.

:raises _InvalidCoordinatesError: If coordinates are invalid.

Source code in rgc/utils/data.py
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def celestial_capture_bulk(
    catalog: pd.DataFrame, survey: str, img_dir: str, classes: Optional[dict] = None, cls_col: Optional[str] = None
) -> None:
    """
    Capture celestial images for a catalog of celestial objects.

    :param catalog: A pandas DataFrame containing the catalog of celestial objects.
    :type catalog: pd.DataFrame

    :param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'.
    :type survey: str

    :param img_dir: The path to the directory to save the images.
    :type img_dir: str

    :param classes: A dictionary containing the classes of the celestial objects.
    :type classes: dict

    :param cls_col: The name of the column containing the class labels.

    :raises _InvalidCoordinatesError: If coordinates are invalid.
    """
    failed = pd.DataFrame(columns=catalog.columns)
    for _, entry in catalog.iterrows():
        try:
            tag = celestial_tag(entry)
            coordinate = SkyCoord(tag, unit=(u.hourangle, u.deg))

            right_ascension = coordinate.ra.deg
            declination = coordinate.dec.deg

            label = _get_class_labels(entry, classes, cls_col) if classes is not None and cls_col is not None else ""

            if "filename" in catalog.columns:
                filename = f'{img_dir}/{label}_{entry["filename"]}.fits'
            else:
                filename = f"{img_dir}/{label}_{tag}.fits"

            celestial_capture(survey, right_ascension, declination, filename)
        except Exception as err:
            series = entry.to_frame().T
            failed = pd.concat([failed, series], ignore_index=True)
            print(f"Failed to capture image. {err}")

celestial_tag(entry)

Generate a name tag for a celestial object based on its coordinates.

:param entry: A pandas Series entry of the catalog. :type entry: pd.Series

:return: A string containing the name tag. :rtype: str

:raises _NoValidCelestialCoordinatesError: If no valid celestial coordinates are found in the entry.

Source code in rgc/utils/data.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def celestial_tag(entry: pd.Series) -> str:
    """
    Generate a name tag for a celestial object based on its coordinates.

    :param entry: A pandas Series entry of the catalog.
    :type entry: pd.Series

    :return: A string containing the name tag.
    :rtype: str

    :raises _NoValidCelestialCoordinatesError: If no valid celestial coordinates are found in the entry.
    """

    def format_dec(dec: str) -> str:
        sign = "+" if float(dec.replace(" ", "")) > 0 else ""
        return f"{sign}{dec}"

    if {"RAJ2000", "DEJ2000"}.issubset(entry.index):
        ra, dec = entry["RAJ2000"], entry["DEJ2000"]
    elif {"RA", "DEC"}.issubset(entry.index):
        ra, dec = entry["RA"], entry["DEC"]
    elif "filename" in entry.index:
        return f"{entry['filename']}"
    elif "FCG" in entry.index:
        return f"{entry['FCG']}"
    else:
        raise _NoValidCelestialCoordinatesError()

    return f"{ra}{format_dec(dec)}"

compute_mean_std(dataloader)

Compute the mean and standard deviation of the dataset.

:param dataloader: The dataloader for the dataset. :type dataloader: torch.utils.data.DataLoader

:return: The mean and standard deviation of the dataset. :rtype: tuple[torch.Tensor, torch.Tensor]

Source code in rgc/utils/data.py
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def compute_mean_std(dataloader: torch.utils.data.DataLoader) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the mean and standard deviation of the dataset.

    :param dataloader: The dataloader for the dataset.
    :type dataloader: torch.utils.data.DataLoader

    :return: The mean and standard deviation of the dataset.
    :rtype: tuple[torch.Tensor, torch.Tensor]
    """
    data = torch.tensor([])
    for batch in dataloader:
        data = torch.cat((data, batch[0]), 0)

    mean = torch.mean(data, dim=(0, 2, 3))
    std = torch.std(data, dim=(0, 2, 3))

    return mean, std

dataframe_to_html(catalog, save_dir)

Save the catalog as an HTML file.

:param catalog: Catalog of the astronomical objects :type catalog: pd.DataFrame :param save_dir: Path to the directory to save the HTML file :type save_dir: str

Source code in rgc/utils/data.py
372
373
374
375
376
377
378
379
380
381
382
def dataframe_to_html(catalog: pd.DataFrame, save_dir: str) -> None:
    """
    Save the catalog as an HTML file.

    :param catalog: Catalog of the astronomical objects
    :type catalog: pd.DataFrame
    :param save_dir: Path to the directory to save the HTML file
    :type save_dir: str
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    catalog.to_html(os.path.join(save_dir, "catalog.html"))

fits_to_png(fits_file, img_size=None)

Convert a FITS file to a PNG image.

:param fits_file: The path to the FITS file. :type fits_file: str

:param img_size: The size of the output image. :type img_size: Optional[tuple[int, int]]

:return: A PIL Image object containing the PNG image. :rtype: Image.Image

:raises _FileNotFoundError: If the FITS file is not found.

Source code in rgc/utils/data.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def fits_to_png(fits_file: str, img_size: Optional[tuple[int, int]] = None) -> Image.Image:
    """
    Convert a FITS file to a PNG image.

    :param fits_file: The path to the FITS file.
    :type fits_file: str

    :param img_size: The size of the output image.
    :type img_size: Optional[tuple[int, int]]

    :return: A PIL Image object containing the PNG image.
    :rtype: Image.Image

    :raises _FileNotFoundError: If the FITS file is not found.
    """
    try:
        image = fits.getdata(fits_file)
        header = fits.getheader(fits_file)
    except FileNotFoundError as err:
        raise _FileNotFoundError(fits_file) from err

    if img_size is not None:
        width, height = img_size
    else:
        width, height = header["NAXIS1"], header["NAXIS2"]

    image = np.reshape(image, (height, width))
    image[np.isnan(image)] = np.nanmin(image)

    image = (image - np.nanmin(image)) / (np.nanmax(image) - np.nanmin(image)) * 255
    image = image.astype(np.uint8)
    image = Image.fromarray(image, mode="L")

    return cast(Image.Image, image)

fits_to_png_bulk(fits_dir, png_dir, img_size=None)

Convert a directory of FITS files to PNG images.

:param fits_dir: The path to the directory containing the FITS files. :type fits_dir: str

:param png_dir: The path to the directory to save the PNG images. :type png_dir: str

:param img_size: The size of the output image. :type img_size: Optional[tuple[int, int]]

Source code in rgc/utils/data.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def fits_to_png_bulk(fits_dir: str, png_dir: str, img_size: Optional[tuple[int, int]] = None) -> None:
    """
    Convert a directory of FITS files to PNG images.

    :param fits_dir: The path to the directory containing the FITS files.
    :type fits_dir: str

    :param png_dir: The path to the directory to save the PNG images.
    :type png_dir: str

    :param img_size: The size of the output image.
    :type img_size: Optional[tuple[int, int]]
    """
    fits_files = Path(fits_dir).rglob("*.fits")
    for fits_file in fits_files:
        image = fits_to_png(str(fits_file), img_size)

        png_file = os.path.join(png_dir, fits_file.stem)
        Path(png_file).parent.mkdir(parents=True, exist_ok=True)

        if image is not None:
            image.save(png_file)

generate_mask(image_path, mask_dir, freq, beam, dilation, threshold_pixel=5.0, threshold_island=3.0)

Detect sources in the image and generate a mask.

:param image_path: Path to the image file :type image_path: str

:param mask_dir: Path to the directory to save the mask :type mask_dir: str

:param freq: Frequency of the image in MHz :type freq: float

:param beam: Beam size of the image in arcsec :type beam: tuple

:param dilation: Dilation factor for the mask :type dilation: int

:param threshold_pixel: Threshold for island peak in number of sigma above the mean :type threshold_pixel: float

:param threshold_island: Threshold for island detection in number of sigma above the mean :type threshold_island: float

Source code in rgc/utils/data.py
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
def generate_mask(
    image_path: str,
    mask_dir: str,
    freq: float,
    beam: tuple[float, float, float],
    dilation: int,
    threshold_pixel: float = 5.0,
    threshold_island: float = 3.0,
) -> None:
    """
    Detect sources in the image and generate a mask.

    :param image_path: Path to the image file
    :type image_path: str

    :param mask_dir: Path to the directory to save the mask
    :type mask_dir: str

    :param freq: Frequency of the image in MHz
    :type freq: float

    :param beam: Beam size of the image in arcsec
    :type beam: tuple

    :param dilation: Dilation factor for the mask
    :type dilation: int

    :param threshold_pixel: Threshold for island peak in number of sigma above the mean
    :type threshold_pixel: float

    :param threshold_island: Threshold for island detection in number of sigma above the mean
    :type threshold_island: float
    """
    try:
        image = bdsf.process_image(
            image_path,
            beam=beam,
            thresh_isl=threshold_island,
            thresh_pix=threshold_pixel,
            frequency=freq,
        )

        mask_file = Path(mask_dir) / Path(image_path).name
        Path(mask_file).parent.mkdir(parents=True, exist_ok=True)

        image.export_image(
            img_type="island_mask",
            outfile=mask_file,
            clobber=True,
            mask_dilation=dilation,
        )

    except Exception:
        print("Failed to generate mask.")
        return None

generate_mask_bulk(catalog, img_dir, mask_dir, freq, beam)

Generate masks for a catalog of celestial objects.

:param catalog: A pandas DataFrame containing the catalog of celestial objects. :type catalog: pd.DataFrame

:param img_dir: The path to the directory containing the images. :type img_dir: str

:param mask_dir: The path to the directory to save the masks. :type mask_dir: str

:param freq: Frequency of the image in MHz :type freq: float

:param beam: Beam size of the image in arcsec :type beam: tuple

Source code in rgc/utils/data.py
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
def generate_mask_bulk(
    catalog: pd.DataFrame, img_dir: str, mask_dir: str, freq: float, beam: tuple[float, float, float]
) -> None:
    """
    Generate masks for a catalog of celestial objects.

    :param catalog: A pandas DataFrame containing the catalog of celestial objects.
    :type catalog: pd.DataFrame

    :param img_dir: The path to the directory containing the images.
    :type img_dir: str

    :param mask_dir: The path to the directory to save the masks.
    :type mask_dir: str

    :param freq: Frequency of the image in MHz
    :type freq: float

    :param beam: Beam size of the image in arcsec
    :type beam: tuple
    """
    for _, entry in catalog.iterrows():
        try:
            filename = entry["filename"]
            image_path = os.path.join(img_dir, f"{filename}.fits")
            dilation = entry["dilation"]
            threshold_pixel = entry["background sigma"]
            threshold_island = entry["foreground sigma"]

            generate_mask(
                image_path,
                mask_dir,
                freq,
                beam,
                dilation,
                threshold_pixel,
                threshold_island,
            )

        except Exception as err:
            print(f"Failed to generate mask. {err}")
            return None

mask_image(image, mask)

Mask an image with a given mask image.

:param image: The image to be masked. :type image: Image.Image

:param mask: The mask image. :type mask: Image.Image

:return: A PIL Image object containing the masked image. :rtype: Image.Image

:raises _ImageMaskDimensionError: If the dimensions of the image and mask do not match.

Source code in rgc/utils/data.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def mask_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """
    Mask an image with a given mask image.

    :param image: The image to be masked.
    :type image: Image.Image

    :param mask: The mask image.
    :type mask: Image.Image

    :return: A PIL Image object containing the masked image.
    :rtype: Image.Image

    :raises _ImageMaskDimensionError: If the dimensions of the image and mask do not match.
    """
    image_array = np.array(image)
    mask_array = np.array(mask)

    if image_array.shape != mask_array.shape:
        raise _ImageMaskDimensionError()

    masked_array = np.where(mask_array == 0, 0, image_array)
    masked_image = Image.fromarray(masked_array, mode="L")

    return cast(Image.Image, masked_image)

mask_image_bulk(image_dir, mask_dir, masked_dir)

Mask a directory of images with a directory of mask images.

:param image_dir: The path to the directory containing the images. :type image_dir: str

:param mask_dir: The path to the directory containing the mask images. :type mask_dir: str

:param masked_dir: The path to the directory to save the masked images. :type masked_dir: str

:raises _FileNotFoundError: If no images or masks are found in the directories. :raises _ImageMaskCountMismatchError: If the number of images and masks do not match.

Source code in rgc/utils/data.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
    """
    Mask a directory of images with a directory of mask images.

    :param image_dir: The path to the directory containing the images.
    :type image_dir: str

    :param mask_dir: The path to the directory containing the mask images.
    :type mask_dir: str

    :param masked_dir: The path to the directory to save the masked images.
    :type masked_dir: str

    :raises _FileNotFoundError: If no images or masks are found in the directories.
    :raises _ImageMaskCountMismatchError: If the number of images and masks do not match.
    """
    image_paths = sorted(Path(image_dir).glob("*.png"))
    mask_paths = sorted(Path(mask_dir).glob("*.png"))

    if len(image_paths) == 0 or len(mask_paths) == 0:
        raise _FileNotFoundError()

    if len(image_paths) != len(mask_paths):
        raise _ImageMaskCountMismatchError() from None

    os.makedirs(masked_dir, exist_ok=True)

    for image_path in image_paths:
        mask_path = Path(mask_dir) / image_path.name

        if not mask_path.exists():
            print(f"Skipping {image_path.name} due to missing mask.")
            continue

        image = Image.open(image_path)
        mask = Image.open(mask_path)

        if image.size != mask.size:
            print(f"Skipping {image_path.name} due to mismatched dimensions.")
            continue
        else:
            masked_image = mask_image(image, mask)

        masked_image.save(Path(masked_dir) / image_path.name)

remove_artifacts(folder, extension)

Remove files with the given extensions from a folder.

:param folder: Path to the folder to clear :type folder: str :param extension: List of file with the given extensions to keep :type extension: list

Source code in rgc/utils/data.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def remove_artifacts(folder: str, extension: list[str]) -> None:
    """
    Remove files with the given extensions from a folder.

    :param folder: Path to the folder to clear
    :type folder: str
    :param extension: List of file with the given extensions to keep
    :type extension: list
    """
    for file in os.listdir(folder):
        if not file.endswith(tuple(extension)):
            os.remove(os.path.join(folder, file))

    print(f"Artifacts removed from {folder} with extensions {', '.join(extension)}")