Coverage for pds_crawler/utils.py: 81%


277 statements  

# -*- coding: utf-8 -*-
# pds-crawler - ETL to index PDS data to pdssp
# Copyright (C) 2023 - CNES (Jean-Christophe Malapert for Pôle Surfaces Planétaires)
# This file is part of pds-crawler <https://github.com/pdssp/pds_crawler>
# SPDX-License-Identifier: LGPL-3.0-or-later
import concurrent.futures
import logging
import os
import time
import tracemalloc
from datetime import datetime
from enum import Enum
from functools import partial
from functools import wraps
from pathlib import Path
from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
from urllib.parse import parse_qs
from urllib.parse import urlparse

import requests
from bs4 import BeautifulSoup
from bs4 import Tag
from fastnumbers import float as ffloat
from fastnumbers import int as iint
from fastnumbers import isfloat
from fastnumbers import isint
from requests.adapters import HTTPAdapter
from tqdm import tqdm
from urllib3 import Retry

from .exception import DateConversionError

logger = logging.getLogger(__name__)
requests.urllib3.disable_warnings(  # type: ignore
    requests.urllib3.exceptions.InsecureRequestWarning  # type: ignore
)


class DocEnum(Enum):
    """Enum where we can add documentation."""

    def __new__(cls, value, doc=None):
        self = object.__new__(
            cls
        )  # calling super().__new__(value) here would fail
        self._value_ = value
        if doc is not None:
            self.__doc__ = doc
        return self

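# A minimal usage sketch for DocEnum (illustrative member names, not part of
# the original module): each member is declared as (value, docstring).
def _example_doc_enum() -> None:
    class Status(DocEnum):
        OK = "ok", "Processing finished without error"
        FAILED = "failed", "Processing aborted on error"

    assert Status.OK.value == "ok"
    assert Status.OK.__doc__ == "Processing finished without error"
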

class UtilsMath:
    """
    The UtilsMath class provides utility functions for working with data types:

    - is_integer: determines whether a given value is an integer.
    - is_float: determines whether a given value is a float.
    - is_bool: determines whether a given value is a boolean.
    - convert_dt: converts a given string value to an appropriate data type (integer, float, boolean or string) when possible.
    """

    @staticmethod
    def is_integer(value: str) -> bool:
        """Determines whether the given value is an integer.

        Args:
            value (str): Value to check.

        Returns:
            bool: True if the value is an integer, False otherwise.
        """
        return isint(value)

    @staticmethod
    def is_float(value: str) -> bool:
        """Determines whether the given value is a float.

        Args:
            value (str): Value to check.

        Returns:
            bool: True if the value is a float, False otherwise.
        """
        return isfloat(value) and not isint(value)

    @staticmethod
    def is_bool(value: str) -> bool:
        """Determines whether the given value is a boolean.

        Args:
            value (str): Value to check.

        Returns:
            bool: True if the value is a boolean, False otherwise.
        """
        if not isinstance(value, str):
            return False
        return value.lower() in [
            "true",
            "t",
            "y",
            "yes",
            "false",
            "f",
            "n",
            "no",
        ]

    @staticmethod
    def convert_dt(value: str) -> Any:
        """Converts the given string value to the appropriate data type if possible.

        Args:
            value (str): Value to convert.

        Returns:
            Any: The converted value.
        """
        result: Any
        if not isinstance(value, str):
            result = value
        elif isint(value):
            result = int(value)
        elif isfloat(value):
            result = ffloat(value)
        elif UtilsMath.is_bool(value):
            result = value.lower() in ("yes", "true", "t")
        else:
            result = value
        return result

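# A short sketch of how UtilsMath.convert_dt maps strings onto Python types
# (illustrative inputs; expected results shown in the assertions):
def _example_convert_dt() -> None:
    assert UtilsMath.convert_dt("42") == 42        # integer string -> int
    assert UtilsMath.convert_dt("3.14") == 3.14    # decimal string -> float
    assert UtilsMath.convert_dt("Yes") is True     # boolean-like string -> bool
    assert UtilsMath.convert_dt("N/A") == "N/A"    # anything else stays a str
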

def cache_download(func):
    """Decorator that checks whether a download has already been done, to avoid redownloading.

    This decorator checks if a file has already been downloaded before calling the `parallel_requests` function.
    If the file has been downloaded previously, it is not downloaded again, and the cached file is used instead.
    If the file has not been downloaded, the function calls `parallel_requests` to download the file.

    Args:
        func (callable): The function to be decorated. It must be `parallel_requests`.

    Returns:
        callable: A decorated function.

    Raises:
        NotImplementedError: If the function being decorated is not `parallel_requests`.
    """

    @wraps(func)
    def cache_download_wrapper(*args, **kwargs):
        """Wrapper function that checks for cached downloads before downloading.

        This wrapper checks whether the URL of each file in the input `urls` list has been downloaded
        previously and skips the download if the file is found in the cache. If the file has not been
        downloaded, the function calls `parallel_requests` to download the file.

        Args:
            *args: Arguments passed to the decorated function.
            **kwargs: Keyword arguments passed to the decorated function.

        Returns:
            The result of the decorated function.

        Raises:
            NotImplementedError: If the function being decorated is not `parallel_requests`.
        """
        # Ensure that the function being wrapped is 'parallel_requests'
        if func.__name__ != "parallel_requests":
            raise NotImplementedError()

        # Get the directory and the list of URLs to be downloaded
        directory: str = args[0]
        urls: List[str] = args[1]
        urls_copy: List[str] = urls.copy()

        # Check if each URL has already been downloaded and remove it from urls_copy if so
        for url in urls:
            filepath: str = compute_downloaded_filepath(directory, url)
            if os.path.exists(filepath):
                logger.warning(f"file {filepath} in cache, skip the download")
                urls_copy.remove(url)
            else:
                logger.info(f"Downloading {url} in progress")

        # Call the original function with the new list of URLs
        if len(args) == 3:
            new_args = (args[0], urls_copy, args[2])
        else:
            new_args = (args[0], urls_copy)
        result = func(*new_args, **kwargs)
        return result

    return cache_download_wrapper


def requests_retry_session(
    retries=3,
    backoff_factor=3,
    status_forcelist=(500, 502, 504),
    session=None,
) -> requests.Session:
    """Builds a requests session with retry support.

    A backoff factor is applied between attempts after the second try (most errors are resolved
    immediately by a second try without a delay). The request will sleep for
    {backoff factor} * (2 ** ({number of total retries} - 1)) seconds.

    Args:
        retries (int, optional): number of retries. Defaults to 3.
        backoff_factor (int, optional): backoff factor. Defaults to 3.
        status_forcelist (tuple, optional): status codes for which a retry must be done. Defaults to (500, 502, 504).
        session (Session, optional): http/https session. Defaults to None.

    Returns:
        requests.Session: session
    """
    session = session or requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session

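# A minimal usage sketch (hypothetical URL). With backoff_factor=3 the sleep
# between retries grows roughly as 3 * 2 ** (retry - 1) seconds:
def _example_retry_session() -> None:
    session = requests_retry_session(retries=3, backoff_factor=3)
    response = session.get("https://example.com/data.json", timeout=30)
    response.raise_for_status()
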

def simple_download(url: str, filepath: str, timeout):
    """Downloads the contents of the given URL and saves them to a file.

    Args:
        url (str): The URL to download.
        filepath (str): The file path where the downloaded contents are saved.
        timeout: The maximum number of seconds to wait for a response from the server.
    """
    # Send a GET request to the URL with the given timeout.
    response = requests.get(
        url, allow_redirects=True, verify=False, timeout=timeout
    )

    # If the response status code is 200 (OK), save the contents to a file.
    if response.status_code == 200:
        # Check if the response content type is HTML.
        if "text/html" in response.headers.get("content-type", ""):
            # If the content type is HTML, check if the response contains a "refresh" meta tag.
            soup = BeautifulSoup(response.content, "html.parser")
            redirect_elt = soup.find(
                "meta", attrs={"http-equiv": "refresh", "content": True}
            )
            if redirect_elt is not None:
                # If a "refresh" tag is found, extract the redirect URL and send another GET request to it.
                redirect_tag = cast(Tag, redirect_elt)
                redirect_tag_value: str = cast(str, redirect_tag["content"])
                redirect_url = (
                    redirect_tag_value.split(";")[1].strip().split("=")[1]
                )
                response = requests.get(
                    redirect_url,
                    allow_redirects=True,
                    verify=False,
                    timeout=timeout,
                )
        # Write the response content to the given file path.
        outfile: Path = Path(filepath)
        outfile.write_bytes(response.content)
    else:
        logger.error(
            f"The request {url} has failed with the error code: {response.status_code}"
        )

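# A minimal usage sketch (hypothetical URL and local path):
def _example_simple_download() -> None:
    simple_download(
        "https://example.com/archive/volume.lbl",  # hypothetical URL
        "/tmp/volume.lbl",
        timeout=180,
    )
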

@cache_download
def parallel_requests(
    directory: str,
    urls: List[str],
    nb_workers: int = 3,
    timeout=180,
    time_sleep=2,
    progress_bar=False,
):
    """Download files from a list of URLs using a ThreadPoolExecutor with a given number of workers.

    Args:
        directory (str): the directory where the downloaded files are saved
        urls (List[str]): a list of URLs to download
        nb_workers (int): the number of workers for the ThreadPoolExecutor
        timeout (int): the maximum time to wait for a response from the server, in seconds
        time_sleep (int): the time to sleep between two requests, in seconds
        progress_bar (bool): whether to show a progress bar

    Raises:
        requests.exceptions.ConnectionError: if a connection error occurs while downloading a file
    """

    def scrape(url):
        """Download a file from a URL.

        Args:
            url (str): the URL to download

        Returns:
            url (str): the URL that has been downloaded
        """
        start = time.time()
        filepath: str = compute_downloaded_filepath(directory, url)
        with requests_retry_session():
            simple_download(url, filepath, timeout)
            time.sleep(time_sleep)
        end = time.time()
        hours, rem = divmod(end - start, 3600)
        minutes, seconds = divmod(rem, 60)
        file_size_bytes = os.path.getsize(filepath)
        file_size_mb = file_size_bytes / 1024**2
        ProgressLogger.write(
            f"{url} downloaded ({file_size_mb:0.03f} MB) in {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}",
            logger,
        )
        return url

    if len(urls) == 0:
        return

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)

    progress_logger = ProgressLogger(
        total=len(urls),
        iterable=None,
        logger=logger,
        description="Downloading data",
        position=1,
        leave=False,
        disable_tqdm=not progress_bar,
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=nb_workers
    ) as executor:
        futures = [executor.submit(scrape, url) for url in urls]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except requests.exceptions.ConnectionError as err:
                logger.exception(f"[parallel_requests]: {err}")
            progress_logger.update(n=1)

    progress_logger.close()

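# A minimal usage sketch (hypothetical URLs). Because of @cache_download,
# running this twice only downloads the files that are not yet on disk:
def _example_parallel_requests() -> None:
    urls = [
        "https://example.com/pds/file1.dat",  # hypothetical
        "https://example.com/pds/file2.dat",  # hypothetical
    ]
    parallel_requests("/tmp/pds_cache", urls, nb_workers=2, progress_bar=True)
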

def compute_downloaded_filepath(directory: str, url: str) -> str:
    """Computes the file path where a downloaded file will be saved based on the provided URL and directory.

    Args:
        directory (str): The directory where the downloaded file will be saved.
        url (str): The URL of the downloaded file.

    Returns:
        str: The file path where the downloaded file will be saved.
    """
    # Parse the URL to extract any query parameters
    parsed_url = urlparse(url)
    params: Dict[str, str] = parse_qs(parsed_url.query)  # type: ignore

    # Generate the filename based on the query parameters or the URL path
    filename: str
    if "ihid" in params:
        # If the URL contains an "ihid" parameter, create a filename using the "target", "ihid", "iid", "pt" and "offset" parameters
        filename = f"{params['target'][0]}_{params['ihid'][0]}_{params['iid'][0]}_{params['pt'][0]}_{params['offset'][0]}.json"
        filename = filename.replace(os.path.sep, "_")
    else:
        # If the URL doesn't contain an "ihid" parameter, create a filename using the URL path
        path: str = parsed_url.path
        filename = os.path.basename(path).lower()

    # Return the full file path where the downloaded file will be saved
    return os.path.join(directory, filename)

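# Two illustrative mappings (hypothetical URLs, POSIX path separators assumed):
def _example_downloaded_filepath() -> None:
    # Query-style URL: the filename is built from the query parameters.
    url1 = "https://example.com/api?target=mars&ihid=mro&iid=ctx&pt=edr&offset=0"
    assert (
        compute_downloaded_filepath("/data", url1)
        == "/data/mars_mro_ctx_edr_0.json"
    )
    # Plain URL: the filename is the lower-cased basename of the path.
    url2 = "https://example.com/archive/VOLUME.LBL"
    assert compute_downloaded_filepath("/data", url2) == "/data/volume.lbl"
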

def compute_download_directory_path(
    directory: str, target: str, ihid: str, iid: str, pt: str, ds: str
) -> str:
    """Computes the path where the file is downloaded based on a base directory and metadata coming from the PDS.

    Args:
        directory (str): base directory
        target (str): solar body
        ihid (str): platform
        iid (str): instrument
        pt (str): product type
        ds (str): collection

    Returns:
        str: the path of the directory
    """
    # Create a list with the names of the five required path components
    items = [
        item.lower().replace(
            os.path.sep, "_"
        )  # Replace occurrences of os.path.sep with underscores in each name
        for item in [
            target,
            ihid,
            iid,
            pt,
            ds,
        ]  # For each required name, make a lowercase copy
    ]
    # Join all required names using os.path.sep as path separator and prepend the base path
    return os.path.join(directory, os.path.sep.join(items))

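# An illustrative call (hypothetical metadata, POSIX separators assumed):
def _example_download_directory() -> None:
    path = compute_download_directory_path(
        "/data", target="MARS", ihid="MRO", iid="CTX", pt="EDR", ds="ctx_edr"
    )
    assert path == "/data/mars/mro/ctx/edr/ctx_edr"
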

def utc_to_iso(utc_time: str, timespec: str = "auto") -> str:
    """Convert a UTC time string to an ISO format string (STAC standard)."""
    # Set the valid datetime formats, e.g. 2018-08-23T23:24:36.865Z
    valid_formats = [
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M:%S.%f",
        "%Y-%m-%dT%H:%M:%S.%fZ",
        "%Y-%m-%dT%H:%M:%SZ",
        "%Y-%m-%d",
    ]
    for valid_format in valid_formats:
        try:
            return datetime.strptime(utc_time, valid_format).isoformat(
                timespec=timespec
            )
        except ValueError:
            continue
    raise DateConversionError(
        f"Cannot convert the time {utc_time} to an ISO string with the following patterns: {valid_formats}"
    )

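# Two illustrative conversions (values chosen to match the formats above):
def _example_utc_to_iso() -> None:
    assert utc_to_iso("2018-08-23T23:24:36.865Z") == "2018-08-23T23:24:36.865000"
    assert utc_to_iso("2018-08-23", timespec="seconds") == "2018-08-23T00:00:00"
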

class Observable:
    """Observable"""

    def __init__(self):
        """Init the observable"""
        self._observers = list()

    def subscribe(self, observer):
        """Subscribe an observer to this observable.

        Args:
            observer (Observer): Observer
        """
        self._observers.append(observer)

    def notify_observers(self, *args, **kwargs):
        """Notify the observers"""
        for obs in self._observers:
            obs.notify(self, *args, **kwargs)

    def unsubscribe(self, observer):
        """Unsubscribe an observer from this observable.

        Args:
            observer (Observer): Observer
        """
        self._observers.remove(observer)

    def unsubscribe_all(self):
        """Unsubscribe all observers."""
        self._observers.clear()


class Observer:  # pylint: disable=R0903
    """Observer"""

    def __init__(self, observable):
        """Init the observer

        Args:
            observable (Observable): Observable to observe
        """
        observable.subscribe(self)

    def notify(self, observable, *args, **kwargs):  # pylint: disable=R0201
        """Notify

        Args:
            observable (Observable): Observable
        """
        print("Got", args, kwargs, "From", observable)

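# A minimal wiring sketch: an Observer subscribes itself on construction and
# receives whatever the Observable broadcasts.
def _example_observer() -> None:
    observable = Observable()
    observer = Observer(observable)  # subscribes itself to observable
    observable.notify_observers("download_done", files=2)  # printed by notify()
    observable.unsubscribe(observer)
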

class UtilsMonitoring:  # noqa: R0205
    """Some utilities."""

    # pylint: disable=invalid-name
    @staticmethod
    def io_display(
        func=None, input=True, output=True, level=15
    ):  # pylint: disable=W0622
        """Monitor the input/output of a function.

        NB: do not use this monitoring method on an __init__ if the class
        implements __repr__ with attributes.

        Parameters
        ----------
        func: func
            function to monitor (default: {None})
        input: bool
            True when the function must monitor the input (default: {True})
        output: bool
            True when the function must monitor the output (default: {True})
        level: int
            Level from which the function must log

        Returns
        -------
        object : the result of the function
        """
        if func is None:
            return partial(
                UtilsMonitoring.io_display,
                input=input,
                output=output,
                level=level,
            )

        @wraps(func)
        def wrapped(*args, **kwargs):
            name = func.__qualname__
            logger = logging.getLogger(__name__ + "." + name)

            if input and logger.isEnabledFor(level):
                msg = f"[{name}] Entering '{name}' (args={args}, kwargs={kwargs})"
                logger.log(level, msg)

            result = func(*args, **kwargs)

            if output and logger.isEnabledFor(level):
                msg = f"[{name}] Exiting '{name}' (result={result})"
                logger.log(level, msg)

            return result

        return wrapped

    @staticmethod
    def timeit(func):
        """Decorator to measure the time spent in a function."""

        @wraps(func)
        def timeit_wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            end_time = time.perf_counter()
            total_time = end_time - start_time
            logger.info(
                f"Function {func.__name__}{args} {kwargs} took {total_time:.4f} seconds"
            )
            return result

        return timeit_wrapper

    @staticmethod
    def measure_memory(func=None, level=logging.DEBUG):
        """Measure the memory used by the function.

        Args:
            func (func, optional): Function to measure. Defaults to None.
            level (int, optional): Level of the log. Defaults to logging.DEBUG.

        Returns:
            object : the result of the function
        """
        if func is None:
            return partial(UtilsMonitoring.measure_memory, level=level)

        @wraps(func)
        def newfunc(*args, **kwargs):
            name = func.__qualname__
            logger = logging.getLogger(__name__ + "." + name)
            tracemalloc.start()
            result = func(*args, **kwargs)
            current, peak = tracemalloc.get_traced_memory()
            msg = f"""
            \033[37mFunction Name       :\033[35;1m {func.__name__}\033[0m
            \033[37mCurrent memory usage:\033[36m {current / 10 ** 6}MB\033[0m
            \033[37mPeak                :\033[36m {peak / 10 ** 6}MB\033[0m
            """
            logger.log(level, msg)
            tracemalloc.stop()
            return result

        return newfunc

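# A minimal sketch stacking the decorators on a hypothetical function:
# entry/exit are logged at level 15 and the duration is logged by timeit.
@UtilsMonitoring.io_display(level=15)
@UtilsMonitoring.timeit
def _example_slow_square(x: int) -> int:
    time.sleep(0.1)
    return x * x
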

class ProgressLogger:
    """A progress logger that can be used with or without tqdm."""

    def __init__(
        self,
        total: int,
        logger: logging.Logger,
        iterable: Union[Iterable, None] = None,
        description: str = "processing",
        disable_tqdm=False,
        *args,
        **kwargs,
    ):
        """A progress logger that can be used with or without tqdm.

        Args:
            total (int): The total number of items to be processed.
            logger (logging.Logger): The logger to use for progress updates.
            iterable (Union[Iterable, None]): The iterable to be processed.
            description (str): The description of the progress bar.
            disable_tqdm (bool): Whether to disable the tqdm progress bar.
            *args: Additional positional arguments to be passed to tqdm.
            **kwargs: Additional keyword arguments to be passed to tqdm.

        Returns:
            None
        """
        self.total = total
        self.disable_tqdm = disable_tqdm
        self.description = description
        self.logger = logger
        self.iterable: Union[Iterable[str], None] = iterable
        self.nb: int = 0
        self.kwargs = kwargs
        if self.iterable is None:
            self.pbar = tqdm(
                total=total,
                desc=description,
                disable=self.disable_tqdm,
                **self.kwargs,
            )
            self._send_message()
        else:
            self.pbar = None  # created later by __enter__

    def _send_message(self):
        """Sends progress update messages to the logger at regular intervals."""
        if self.disable_tqdm and self.nb % 10 == 0:
            msg = f"{self.description} : {int(self.nb / self.total * 100)}%"
            self.logger.info(msg)

    def __enter__(self):
        """Called when the 'with' statement is entered. Initializes the progress bar.

        Returns:
            self: The ProgressLogger instance.
        """
        self.pbar = tqdm(
            total=self.total,
            desc=self.description,
            disable=self.disable_tqdm,
            **self.kwargs,
        )
        self._send_message()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Called when the 'with' statement is exited. Closes the progress bar.

        Args:
            exc_type (type): The type of the exception, if one occurred.
            exc_value (Exception): The exception object, if one occurred.
            traceback (traceback): The traceback object, if one occurred.
        """
        if self.pbar:
            self.pbar.close()

    def __iter__(self):
        """Iterates over the iterable and updates the progress bar.

        Yields:
            i: The next item in the iterable.
        """
        if self.iterable and not self.disable_tqdm:
            for i in self.iterable:
                self.pbar.update(1)
                yield i
        elif self.iterable:
            for i in self.iterable:
                self.nb += 1
                self._send_message()
                yield i
        else:
            yield None

    def update(self, n: int):
        """Updates the progress bar with the specified number of items.

        Args:
            n (int): The number of items to update.
        """
        if self.disable_tqdm:
            self.nb += n
            self._send_message()
        else:
            self.pbar.update(n=n)

    def write_msg(self, msg: str):
        """Write a message using tqdm or the logger, depending on whether tqdm is used."""
        if self.disable_tqdm:
            logger.info(msg)
        else:
            tqdm.write(msg)

    @staticmethod
    def write(msg: str, logger: logging.Logger):
        """Write a message using the logger."""
        logger.info(msg)

    def close(self):
        """Close the tqdm progress bar."""
        if self.pbar:
            self.pbar.close()

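# A minimal usage sketch: with tqdm disabled, progress is reported through
# the logger (one message every 10 items) instead of a progress bar.
def _example_progress_logger() -> None:
    with ProgressLogger(
        total=100, logger=logger, iterable=range(100), disable_tqdm=True
    ) as plogger:
        for _ in plogger:
            pass
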

class Locking:
    """Utility class for locking a file"""

    @staticmethod
    def lock_file(file):
        """Creates a lock with the same name as the input file, but with ".lock" appended to the end.

        This method can be used to prevent other processes from accessing the same file simultaneously.

        Args:
            file (str): The name of the file to lock.

        Returns:
            None
        """
        # Create the lock name by appending ".lock" to the input file name
        lock_file = file + ".lock"

        # Initialize the locking flag to False
        locking = False

        # While the file is not locked, keep trying to acquire the lock
        while not locking:
            try:
                # Try to create the lock directory using the mkdir method (atomic)
                os.mkdir(lock_file)

                # If successful, set the locking flag to True and exit the loop
                locking = True
            except OSError:
                # If the lock already exists, wait for 0.1 seconds and try again
                time.sleep(0.1)

    @staticmethod
    def unlock_file(file: str):
        """Remove the lock

        Args:
            file (str): The name of the file to unlock.
        """
        os.rmdir(file + ".lock")
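
# A minimal sketch of guarding a shared file between processes (hypothetical
# path); the mkdir-based lock is atomic, so only one process holds it at a time.
def _example_locking() -> None:
    shared = "/tmp/shared_state.json"  # hypothetical shared file
    Locking.lock_file(shared)
    try:
        pass  # read/modify/write the shared file here
    finally:
        Locking.unlock_file(shared)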