Coverage for pds_crawler/utils.py: 81%
# -*- coding: utf-8 -*-
# pds-crawler - ETL to index PDS data to pdssp
# Copyright (C) 2023 - CNES (Jean-Christophe Malapert for Pôle Surfaces Planétaires)
# This file is part of pds-crawler <https://github.com/pdssp/pds_crawler>
# SPDX-License-Identifier: LGPL-3.0-or-later
import concurrent.futures
import logging
import os
import time
import tracemalloc
from datetime import datetime
from enum import Enum
from functools import partial
from functools import wraps
from pathlib import Path
from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
from urllib.parse import parse_qs
from urllib.parse import urlparse

import requests
from bs4 import BeautifulSoup
from bs4 import Tag
from fastnumbers import float as ffloat
from fastnumbers import int as iint
from fastnumbers import isfloat
from fastnumbers import isint
from requests.adapters import HTTPAdapter
from tqdm import tqdm
from urllib3 import Retry

from .exception import DateConversionError

logger = logging.getLogger(__name__)
requests.urllib3.disable_warnings(  # type: ignore
    requests.urllib3.exceptions.InsecureRequestWarning  # type: ignore
)


class DocEnum(Enum):
    """Enum where we can add documentation."""

    def __new__(cls, value, doc=None):
        self = object.__new__(
            cls
        )  # calling super().__new__(value) here would fail
        self._value_ = value
        if doc is not None:
            self.__doc__ = doc
        return self
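

# Usage sketch: DocEnum lets each member carry its own docstring. _DemoColor
# is a hypothetical enum, shown only to illustrate the (value, doc) tuple form.
class _DemoColor(DocEnum):
    RED = ("red", "Color used for errors")
    GREEN = ("green", "Color used for success")

# _DemoColor.RED.value -> "red"; _DemoColor.RED.__doc__ -> "Color used for errors"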


class UtilsMath:
    """
    The UtilsMath class provides some utility functions for working with data types:

    - is_integer: determines whether a given value is an integer or not.
    - is_float: determines whether a given value is a float or not.
    - is_bool: determines whether a given value is a boolean or not.
    - convert_dt: attempts to convert a given string value to an appropriate data type (integer, float, boolean or string) if possible.
    """

    @staticmethod
    def is_integer(value: str) -> bool:
        """Determines whether the given value is an integer or not.

        Args:
            value (str): Value to check.

        Returns:
            bool: True if the value is an integer, False otherwise.
        """
        return isint(value)

    @staticmethod
    def is_float(value: str) -> bool:
        """Determines whether the given value is a float (and not an integer) or not.

        Args:
            value (str): Value to check.

        Returns:
            bool: True if the value is a float, False otherwise.
        """
        return isfloat(value) and not isint(value)

    @staticmethod
    def is_bool(value: str) -> bool:
        """Determines whether the given value is a boolean or not.

        Args:
            value (str): Value to check.

        Returns:
            bool: True if the value is a boolean, False otherwise.
        """
        if not isinstance(value, str):
            return False
        return value.lower() in [
            "true",
            "t",
            "y",
            "yes",
            "false",
            "f",
            "n",
            "no",
        ]

    @staticmethod
    def convert_dt(value: str) -> Any:
        """Converts the given string value to the appropriate data type if possible.

        Args:
            value (str): Value to convert.

        Returns:
            Any: The converted value.
        """
        result: Any
        if not isinstance(value, str):
            result = value
        elif isint(value):
            result = int(value)
        elif isfloat(value):
            result = ffloat(value)
        elif UtilsMath.is_bool(value):
            result = value.lower() in ("yes", "true", "t")
        else:
            result = value
        return result
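

# Usage sketch: expected conversions performed by UtilsMath.convert_dt;
# the sample values are hypothetical.
def _demo_convert_dt() -> None:
    assert UtilsMath.convert_dt("42") == 42  # integer string -> int
    assert UtilsMath.convert_dt("3.14") == 3.14  # float string -> float
    assert UtilsMath.convert_dt("yes") is True  # boolean-like string -> bool
    assert UtilsMath.convert_dt("MARS") == "MARS"  # anything else stays a str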


def cache_download(func):
    """Decorator to check if the download has been previously done and avoid redownloading.

    This decorator checks if a file has already been downloaded before calling the `parallel_requests` function.
    If the file has been downloaded previously, it is not downloaded again, and the cached file is used instead.
    If the file has not been downloaded, the function calls `parallel_requests` to download the file.

    Args:
        func (callable): The function to be decorated. It should be `parallel_requests`.

    Returns:
        callable: A decorated function.

    Raises:
        NotImplementedError: If the function being decorated is not `parallel_requests`.
    """

    @wraps(func)
    def cache_download_wrapper(*args, **kwargs):
        """Wrapper function that checks for cached downloads before downloading.

        This wrapper function checks if the URL of each file in the input `urls` list has been downloaded
        previously and skips the download if the file is found in the cache. If the file has not been
        downloaded, the function calls `parallel_requests` to download the file.

        Args:
            *args: Arguments passed to the decorated function.
            **kwargs: Keyword arguments passed to the decorated function.

        Returns:
            The result of the decorated function.

        Raises:
            NotImplementedError: If the function being decorated is not `parallel_requests`.
        """
        # Ensure that the function being wrapped is 'parallel_requests'
        if func.__name__ != "parallel_requests":
            raise NotImplementedError()

        # Get the directory and list of URLs to be downloaded
        directory: str = args[0]
        urls: List[str] = args[1]
        urls_copy: List[str] = urls.copy()

        # Check if each URL has already been downloaded and remove from urls_copy if so
        for url in urls:
            filepath: str = compute_downloaded_filepath(directory, url)
            if os.path.exists(filepath):
                logger.warning(f"file {filepath} in cache, skip the download")
                urls_copy.remove(url)
            else:
                logger.info(f"Downloading {url} in progress")

        # Call the original function with the new list of URLs
        if len(args) == 3:
            new_args = (args[0], urls_copy, args[2])
        else:
            new_args = (args[0], urls_copy)
        result = func(*new_args, **kwargs)
        return result

    return cache_download_wrapper


def requests_retry_session(
    retries=3,
    backoff_factor=3,
    status_forcelist=(500, 502, 504),
    session=None,
) -> requests.Session:
    """Requests session with retry.

    A backoff factor is applied between attempts after the second try (most errors are resolved immediately by a
    second try without a delay). The request will sleep for:
    {backoff factor} * (2 ** ({number of total retries} - 1))

    Args:
        retries (int, optional): number of retries. Defaults to 3.
        backoff_factor (int, optional): backoff factor. Defaults to 3.
        status_forcelist (tuple, optional): statuses for which the retry must be done. Defaults to (500, 502, 504).
        session (Session, optional): http/https session. Defaults to None.

    Returns:
        requests.Session: session
    """
    session = session or requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session
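

# Usage sketch: building a retry-enabled session; the URL is hypothetical.
# Per the formula above, waits between attempts grow exponentially
# (e.g. on the order of 3s, 6s, 12s with backoff_factor=3).
def _demo_retry_session() -> None:
    session = requests_retry_session(retries=3, backoff_factor=3)
    response = session.get("https://example.com/data.json", timeout=30)
    print(response.status_code)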


def simple_download(
    url: str,
    filepath: str,
    timeout,
    session: Optional[requests.Session] = None,
):
    """Downloads the contents of the given URL and saves it to a file.

    Args:
        url (str): The URL to download.
        filepath (str): The file path to save the downloaded contents.
        timeout: The maximum number of seconds to wait for a response from the server.
        session (requests.Session, optional): Session to use for the requests, e.g. one
            built by `requests_retry_session`. Defaults to None, i.e. plain `requests`.
    """
    # Use the given session if any, so that its retry/backoff settings apply.
    requester = session or requests

    # Send a GET request to the URL with the given timeout.
    response = requester.get(
        url, allow_redirects=True, verify=False, timeout=timeout
    )

    # If the response status code is 200 (OK), save the contents to a file.
    if response.status_code == 200:
        # Check if the response content type is HTML.
        if "text/html" in response.headers.get("content-type", ""):
            # If the content type is HTML, check if the response contains a "refresh" meta tag.
            soup = BeautifulSoup(response.content, "html.parser")
            redirect_elt = soup.find(
                "meta", attrs={"http-equiv": "refresh", "content": True}
            )
            if redirect_elt is not None:
                # If a "refresh" tag is found, extract the redirect URL and send another GET request to it.
                redirect_tag = cast(Tag, redirect_elt)
                redirect_tag_value: str = cast(str, redirect_tag["content"])
                redirect_url = (
                    redirect_tag_value.split(";")[1].strip().split("=")[1]
                )
                response = requester.get(
                    redirect_url,
                    allow_redirects=True,
                    verify=False,
                    timeout=timeout,
                )
        # Write the response content to the given file path.
        outfile: Path = Path(filepath)
        outfile.write_bytes(response.content)
    else:
        logger.error(
            f"The request {url} has failed with the error code: {response.status_code}"
        )


@cache_download
def parallel_requests(
    directory: str,
    urls: List[str],
    nb_workers: int = 3,
    timeout=180,
    time_sleep=2,
    progress_bar=False,
):
    """Download files from a list of URLs using a ThreadPoolExecutor with a given number of workers.

    Args:
        directory (str): the directory where to save the downloaded files
        urls (List[str]): a list of URLs to download
        nb_workers (int): the number of workers for the ThreadPoolExecutor
        timeout (int): the maximum time to wait for a response from the server, in seconds
        time_sleep (int): the time to sleep between two requests, in seconds
        progress_bar (bool): whether to show a progress bar or not

    Raises:
        requests.exceptions.ConnectionError: if a connection error occurs while downloading a file
    """

    def scrape(url):
        """Download a file from a URL.

        Args:
            url (str): the URL to download

        Returns:
            url (str): the URL that has been downloaded
        """
        start = time.time()
        filepath: str = compute_downloaded_filepath(directory, url)
        # Use a retry-enabled session so that transient errors are retried.
        with requests_retry_session() as session:
            simple_download(url, filepath, timeout, session=session)
        time.sleep(time_sleep)
        end = time.time()
        hours, rem = divmod(end - start, 3600)
        minutes, seconds = divmod(rem, 60)
        file_size_bytes = os.path.getsize(filepath)
        file_size_mb = file_size_bytes / 1024**2
        ProgressLogger.write(
            f"{url} downloaded ({file_size_mb:0.03f} MB) in {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}",
            logger,
        )
        return url

    if len(urls) == 0:
        return

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)

    progress_logger = ProgressLogger(
        total=len(urls),
        iterable=None,
        logger=logger,
        description="Downloading data",
        position=1,
        leave=False,
        disable_tqdm=not progress_bar,
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=nb_workers
    ) as executor:
        futures = [executor.submit(scrape, url) for url in urls]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except requests.exceptions.ConnectionError as err:
                logger.exception(f"[parallel_requests]: {err}")
            progress_logger.update(n=1)

    progress_logger.close()
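

# Usage sketch: downloading a batch of files; the URLs and directory are
# hypothetical. Thanks to the @cache_download decorator, a second identical
# call finds the files on disk and skips the download.
def _demo_parallel_requests() -> None:
    urls = [
        "https://example.com/catalog/VOLDESC.CAT",
        "https://example.com/catalog/MISSION.CAT",
    ]
    parallel_requests("/tmp/pds_cache", urls, nb_workers=2, progress_bar=True)
    # Second call: both files are in the cache, so nothing is downloaded.
    parallel_requests("/tmp/pds_cache", urls, nb_workers=2, progress_bar=True)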


def compute_downloaded_filepath(directory: str, url: str) -> str:
    """Computes the file path where a downloaded file will be saved based on the provided URL and directory.

    Args:
        directory (str): The directory where the downloaded file will be saved.
        url (str): The URL of the downloaded file.

    Returns:
        str: The file path where the downloaded file will be saved.
    """
    # Parse the URL to extract any query parameters
    parsed_url = urlparse(url)
    params: Dict[str, str] = parse_qs(parsed_url.query)  # type: ignore

    # Generate the filename based on the query parameters or the URL path
    filename: str
    if "ihid" in params:
        # If the URL contains "ihid" parameter, create a filename using "target", "ihid", "iid", "pt", and "offset" parameters
        filename = f"{params['target'][0]}_{params['ihid'][0]}_{params['iid'][0]}_{params['pt'][0]}_{params['offset'][0]}.json"
        filename = filename.replace(os.path.sep, "_")
    else:
        # If the URL doesn't contain "ihid" parameter, create a filename using the URL path
        path: str = parsed_url.path
        filename = os.path.basename(path).lower()

    # Return the full file path where the downloaded file will be saved
    return os.path.join(directory, filename)
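

# Usage sketch: the two filename strategies of compute_downloaded_filepath;
# the URLs are hypothetical and the printed paths assume a POSIX separator.
def _demo_compute_downloaded_filepath() -> None:
    # URL with an "ihid" query parameter -> name built from the PDS parameters.
    url = (
        "https://example.com/api?target=mars&ihid=mro&iid=hirise"
        "&pt=rdrv11&offset=0"
    )
    print(compute_downloaded_filepath("/tmp", url))
    # -> /tmp/mars_mro_hirise_rdrv11_0.json
    # Plain URL -> lowercased basename of the URL path.
    print(compute_downloaded_filepath("/tmp", "https://example.com/CATALOG.CAT"))
    # -> /tmp/catalog.cat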


def compute_download_directory_path(
    directory: str, target: str, ihid: str, iid: str, pt: str, ds: str
) -> str:
    """Computes the path where the file is downloaded based on a base directory and metadata coming from the PDS.

    Args:
        directory (str): base directory
        target (str): solar body
        ihid (str): platform
        iid (str): instrument
        pt (str): product type
        ds (str): collection

    Returns:
        str: the path of the directory
    """
    # Create a list with the names of the five required folders/files
    items = [
        item.lower().replace(
            os.path.sep, "_"
        )  # Replace occurrences of os.path.sep with underscores in each name
        for item in [
            target,
            ihid,
            iid,
            pt,
            ds,
        ]  # For each required name, make a lowercase copy
    ]
    # Join all required names using os.path.sep as path separator and add the base path
    return os.path.join(directory, os.path.sep.join(items))
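

# Usage sketch: building the on-disk layout from PDS metadata; the values are
# hypothetical and the printed path assumes a POSIX separator.
def _demo_compute_download_directory_path() -> None:
    path = compute_download_directory_path(
        "/data", target="MARS", ihid="MRO", iid="HIRISE", pt="RDRV11", ds="DS1"
    )
    print(path)  # -> /data/mars/mro/hirise/rdrv11/ds1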


def utc_to_iso(utc_time: str, timespec: str = "auto") -> str:
    """Convert a UTC time string to an ISO format string (STAC standard)."""
    # set valid datetime formats, e.g. 2018-08-23T23:24:36.865Z
    valid_formats = [
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M:%S.%f",
        "%Y-%m-%dT%H:%M:%S.%fZ",
        "%Y-%m-%dT%H:%M:%SZ",
        "%Y-%m-%d",
    ]
    for valid_format in valid_formats:
        try:
            return datetime.strptime(utc_time, valid_format).isoformat(
                timespec=timespec
            )
        except ValueError:
            continue
    raise DateConversionError(
        f"Cannot convert the time {utc_time} to an ISO string with the following patterns {valid_formats}"
    )
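

# Usage sketch: a few conversions accepted (or rejected) by utc_to_iso.
def _demo_utc_to_iso() -> None:
    print(utc_to_iso("2018-08-23T23:24:36.865Z"))  # 2018-08-23T23:24:36.865000
    print(utc_to_iso("2018-08-23"))  # 2018-08-23T00:00:00
    # A format outside valid_formats raises DateConversionError:
    # utc_to_iso("23/08/2018")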


class Observable:
    """Observable"""

    def __init__(self):
        """Init the observable"""
        self._observers = list()

    def subscribe(self, observer):
        """Subscribe the observer to the observable

        Args:
            observer (Observer): Observer
        """
        self._observers.append(observer)

    def notify_observers(self, *args, **kwargs):
        """Notify the observers"""
        for obs in self._observers:
            obs.notify(self, *args, **kwargs)

    def unsubscribe(self, observer):
        """Unsubscribe the observer

        Args:
            observer (Observer): Observer
        """
        self._observers.remove(observer)

    def unsubscribe_all(self):
        """Unsubscribe all observers."""
        self._observers.clear()


class Observer:  # pylint: disable=R0903
    """Observer"""

    def __init__(self, observable):
        """Init the observer

        Args:
            observable (Observable): Observable to observe
        """
        observable.subscribe(self)

    def notify(self, observable, *args, **kwargs):  # pylint: disable=R0201
        """Notify

        Args:
            observable (Observable): Observable
        """
        print("Got", args, kwargs, "From", observable)
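

# Usage sketch: wiring an Observable to an Observer; the event payload is
# hypothetical.
def _demo_observer_pattern() -> None:
    observable = Observable()
    observer = Observer(observable)  # subscribes itself in __init__
    observable.notify_observers("download_done", url="https://example.com/a.lbl")
    observable.unsubscribe(observer)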


class UtilsMonitoring:  # noqa: R0205
    """Some Utilities."""

    # pylint: disable=invalid-name
    @staticmethod
    def io_display(
        func=None, input=True, output=True, level=15
    ):  # pylint: disable=W0622
        """Monitor the input/output of a function.

        NB: Do not use this monitoring method on an __init__ if the class
        implements __repr__ with attributes

        Parameters
        ----------
        func: func
            function to monitor (default: {None})
        input: bool
            True when the function must monitor the input (default: {True})
        output: bool
            True when the function must monitor the output (default: {True})
        level: int
            Level from which the function must log

        Returns
        -------
        object : the result of the function
        """
        if func is None:
            return partial(
                UtilsMonitoring.io_display,
                input=input,
                output=output,
                level=level,
            )

        @wraps(func)
        def wrapped(*args, **kwargs):
            name = func.__qualname__
            logger = logging.getLogger(__name__ + "." + name)

            if input and logger.isEnabledFor(level):
                msg = f"[{name}] Entering '{name}' (args={args}, kwargs={kwargs})"
                logger.log(level, msg)

            result = func(*args, **kwargs)

            if output and logger.isEnabledFor(level):
                msg = f"[{name}] Exiting '{name}' (result={result})"
                logger.log(level, msg)

            return result

        return wrapped

    @staticmethod
    def timeit(func):
        """Decorator to measure the time spent in a function"""

        @wraps(func)
        def timeit_wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            end_time = time.perf_counter()
            total_time = end_time - start_time
            logger.info(
                f"Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds"
            )
            return result

        return timeit_wrapper

    @staticmethod
    def measure_memory(func=None, level=logging.DEBUG):
        """Measure the memory used by the function

        Args:
            func (func, optional): Function to measure. Defaults to None.
            level (int, optional): Level of the log. Defaults to logging.DEBUG.

        Returns:
            object : the result of the function
        """
        if func is None:
            return partial(UtilsMonitoring.measure_memory, level=level)

        @wraps(func)
        def newfunc(*args, **kwargs):
            name = func.__qualname__
            logger = logging.getLogger(__name__ + "." + name)
            tracemalloc.start()
            result = func(*args, **kwargs)
            current, peak = tracemalloc.get_traced_memory()
            msg = f"""
            \033[37mFunction Name :\033[35;1m {func.__name__}\033[0m
            \033[37mCurrent memory usage:\033[36m {current / 10 ** 6}MB\033[0m
            \033[37mPeak :\033[36m {peak / 10 ** 6}MB\033[0m
            """
            logger.log(level, msg)
            tracemalloc.stop()
            return result

        return newfunc
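

# Usage sketch: stacking the monitoring decorators on a hypothetical function.
# Calling _demo_monitored_sum logs its inputs/outputs, duration and memory use.
@UtilsMonitoring.io_display(level=logging.INFO)
@UtilsMonitoring.timeit
@UtilsMonitoring.measure_memory(level=logging.INFO)
def _demo_monitored_sum(values: List[int]) -> int:
    return sum(values)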


class ProgressLogger:
    """A progress logger that can be used with or without tqdm."""

    def __init__(
        self,
        total: int,
        logger: logging.Logger,
        iterable: Union[Iterable, None] = None,
        description: str = "processing",
        disable_tqdm=False,
        *args,
        **kwargs,
    ):
        """A progress logger that can be used with or without tqdm.

        Args:
            total (int): The total number of items to be processed.
            logger (logging.Logger): The logger to use for progress updates.
            iterable (Union[Iterable, None]): The iterable to be processed.
            description (str): The description of the progress bar.
            disable_tqdm (bool): Whether to disable the tqdm progress bar.
            *args: Additional positional arguments to be passed to tqdm.
            **kwargs: Additional keyword arguments to be passed to tqdm.
        """
        self.total = total
        self.disable_tqdm = disable_tqdm
        self.description = description
        self.logger = logger
        self.iterable: Union[Iterable[str], None] = iterable
        self.nb: int = 0
        self.kwargs = kwargs
        self.pbar = None
        if self.iterable is None:
            self.pbar = tqdm(
                total=total,
                desc=description,
                disable=self.disable_tqdm,
                **self.kwargs,
            )
            self._send_message()

    def _send_message(self):
        """Sends progress update messages to the logger at regular intervals."""
        if self.disable_tqdm and self.nb % 10 == 0:
            msg = f"{self.description} : {int(self.nb / self.total * 100)}%"
            self.logger.info(msg)

    def __enter__(self):
        """Called when the 'with' statement is entered. Initializes the progress bar.

        Returns:
            self: The ProgressLogger instance.
        """
        self.pbar = tqdm(
            total=self.total,
            desc=self.description,
            disable=self.disable_tqdm,
            **self.kwargs,
        )
        self._send_message()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Called when the 'with' statement is exited. Closes the progress bar.

        Args:
            exc_type (type): The type of the exception, if one occurred.
            exc_value (Exception): The exception object, if one occurred.
            traceback (traceback): The traceback object, if one occurred.
        """
        if self.pbar:
            self.pbar.close()

    def __iter__(self):
        """Iterates over the iterable and updates the progress bar.

        Yields:
            i: The next item in the iterable.
        """
        if self.iterable and not self.disable_tqdm:
            for i in self.iterable:
                self.pbar.update(1)
                yield i
        elif self.iterable:
            for i in self.iterable:
                self.nb += 1
                self._send_message()
                yield i
        else:
            yield None

    def update(self, n: int):
        """Updates the progress bar with the specified number of items.

        Args:
            n (int): The number of items to update.
        """
        if self.disable_tqdm:
            self.nb += n
            self._send_message()
        else:
            self.pbar.update(n=n)

    def write_msg(self, msg: str):
        """Write a message using tqdm or the logger, depending on whether tqdm is enabled."""
        if self.disable_tqdm:
            self.logger.info(msg)
        else:
            tqdm.write(msg)

    @staticmethod
    def write(msg: str, logger: logging.Logger):
        """Write a message using the logger."""
        logger.info(msg)

    def close(self):
        """Close the tqdm progress bar."""
        if self.pbar:
            self.pbar.close()
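

# Usage sketch: the two modes of ProgressLogger, a tqdm bar or plain log
# messages; the items are hypothetical.
def _demo_progress_logger() -> None:
    items = [f"item_{i}" for i in range(20)]
    # tqdm mode: a progress bar is displayed while iterating.
    with ProgressLogger(total=len(items), logger=logger, iterable=items) as pl:
        for _ in pl:
            pass
    # logger mode: a "processing : N%" message is logged every 10 items.
    with ProgressLogger(
        total=len(items), logger=logger, iterable=items, disable_tqdm=True
    ) as pl:
        for _ in pl:
            pass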


class Locking:
    """Utility class for locking a file"""

    @staticmethod
    def lock_file(file):
        """Creates a lock file with the same name as the input file, but with ".lock" appended to the end.
        This method can be used to prevent other processes from accessing the same file simultaneously.

        Args:
            file (str): The name of the file to lock.

        Returns:
            None
        """
        # Create a lock file by appending ".lock" to the input file name
        lock_file = file + ".lock"

        # Initialize the locking flag to False
        locking = False

        # While the file is not locked, keep trying to create the lock file
        while not locking:
            try:
                # Try to create the lock as a directory: os.mkdir is atomic
                os.mkdir(lock_file)

                # If successful, set the locking flag to True and exit the loop
                locking = True
            except OSError:
                # If the lock already exists, wait for 0.1 seconds and try again
                time.sleep(0.1)

    @staticmethod
    def unlock_file(file: str):
        """Remove the lock file

        Args:
            file (str): The name of the file to unlock.
        """
        os.rmdir(file + ".lock")
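

# Usage sketch: guarding a shared file with the lock helpers; the path is
# hypothetical.
def _demo_locking() -> None:
    shared = "/tmp/pds_index.json"
    Locking.lock_file(shared)  # blocks until the .lock directory can be created
    try:
        with open(shared, "a") as handle:
            handle.write("new entry\n")
    finally:
        Locking.unlock_file(shared)  # removes the .lock directory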