# Source code for iceprod.core.functions

"""
Common functions
"""

from __future__ import absolute_import, division, print_function

import os
import shutil
import logging
import socket
import subprocess
import hashlib
from functools import partial, reduce
from contextlib import contextmanager
import asyncio

try:
    import psutil
except ImportError:
    psutil = None

import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder
from rest_tools.client import Session, AsyncSession

from iceprod.core.gridftp import GridFTP


# Compression Functions #
# Filename suffixes recognized as single-file compression formats.
_compress_suffixes = ('.tgz','.gz','.tbz2','.tbz','.bz2','.bz',
                      '.lzma2','.lzma','.lz','.xz')
# Filename suffixes recognized as tar archives (plain or compressed).
_tar_suffixes = ('.tar', '.tar.gz', '.tgz', '.tar.bz2', '.tbz2', '.tbz',
                 '.tar.lzma', '.tar.xz', '.tlz', '.txz')


def uncompress(infile, out_dir=None):
    """Uncompress a file, if possible.

    Tar archives are extracted in place (or in ``out_dir``); plain
    compressed files are decompressed next to the original.

    Returns the extracted filename, or a list of filenames for
    multi-member archives.
    """
    extracted = []
    original_dir = os.getcwd()
    try:
        if out_dir:
            os.chdir(out_dir)
        logging.info('uncompressing %s', infile)
        if istarred(infile):
            # tar archive: list members first so we can report them
            listing = subprocess.check_output(['tar','-atf',infile]).decode('utf-8')
            extracted = [entry for entry in listing.split('\n')
                         if entry.strip() and not entry.endswith('/')]
            if not extracted:
                raise Exception('no files inside tarfile')
            # skip extraction entirely when any member already exists
            if not any(os.path.exists(entry) for entry in extracted):
                subprocess.call(['tar','-axf',infile])
        else:
            # single compressed file: pick the matching decompression tool
            tool = None
            for suffixes, candidate in ((('.gz',), 'gzip'),
                                        (('.bz', '.bz2'), 'bzip2'),
                                        (('.xz', '.lzma'), 'xz')):
                if infile.endswith(suffixes):
                    tool = candidate
                    break
            if tool is None:
                logging.info('unknown format: %s', infile)
                raise Exception('unknown format')
            subprocess.call([tool,'-kdf',infile])
            extracted.append(infile.rsplit('.',1)[0])
    finally:
        os.chdir(original_dir)
    logging.info('files: %r', extracted)
    if len(extracted) == 1:
        return extracted[0]
    return extracted
def compress(infile, compression='lzma'):
    """Compress a file or directory.

    The compression argument is used as the new file extension.
    """
    # directories (with a non-tar compression) become .tar.<compression>
    if os.path.isdir(infile) and not istarred('.'+compression):
        outfile = '{}.tar.{}'.format(infile, compression)
    else:
        outfile = '{}.{}'.format(infile, compression)
    if istarred(outfile):
        parent, name = os.path.split(infile)
        subprocess.call(['tar','-acf',outfile,'-C',parent,name])
    else:
        if outfile.endswith('.gz'):
            cmd = ['gzip']
        elif outfile.endswith(('.bz', '.bz2')):
            cmd = ['bzip2']
        elif outfile.endswith('.xz'):
            cmd = ['xz']
        elif outfile.endswith('.lzma'):
            cmd = ['xz','-F','lzma']
        else:
            logging.info('unknown format: %s',infile)
            raise Exception('unknown format')
        subprocess.call(cmd+['-kf',infile])
    return outfile
def iscompressed(infile):
    """Check if a file is a compressed file, based on file name"""
    # str.endswith accepts a tuple of suffixes directly
    return infile.endswith(_compress_suffixes)
def istarred(infile):
    """Check if a file is a tarred file, based on file name"""
    # str.endswith accepts a tuple of suffixes directly
    return infile.endswith(_tar_suffixes)
def cksm(filename, type, buffersize=16384, file=True):
    """Return checksum of file using algorithm specified.

    Args:
        filename: path of the file to checksum, or the raw bytes to
            checksum when ``file`` is False (or the path does not exist)
        type: one of 'md5', 'sha1', 'sha256', 'sha512'
        buffersize: chunk size for reading the file
        file: if True, treat ``filename`` as a path on disk

    Returns:
        hex digest string

    Raises:
        Exception: if ``type`` is not a supported algorithm
    """
    # bug fix: the message was Exception('...%r', type), which never
    # interpolates %r and stringifies as a tuple; format it properly
    if type not in ('md5','sha1','sha256','sha512'):
        raise Exception('cannot get checksum for type %r' % (type,))
    # the whitelist above guarantees hashlib has this constructor,
    # so the old try/except around getattr was dead code
    digest = getattr(hashlib, type)()
    if file and os.path.exists(filename):
        # checksum file contents in fixed-size chunks to bound memory use
        with open(filename,'rb') as filed:
            buffer = filed.read(buffersize)
            while buffer:
                digest.update(buffer)
                buffer = filed.read(buffersize)
    else:
        # just checksum the contents of the first argument (bytes expected)
        digest.update(filename)
    return digest.hexdigest()
def md5sum(filename, buffersize=16384):
    """Return md5 digest of file"""
    return cksm(filename, 'md5', buffersize=buffersize)
def sha1sum(filename, buffersize=16384):
    """Return sha1 digest of file"""
    return cksm(filename, 'sha1', buffersize=buffersize)
def sha256sum(filename, buffersize=16384):
    """Return sha256 digest of file"""
    return cksm(filename, 'sha256', buffersize=buffersize)
def sha512sum(filename, buffersize=16384):
    """Return sha512 digest of file"""
    return cksm(filename, 'sha512', buffersize=buffersize)
def load_cksm(sumfile, base_filename):
    """Load the checksum from a file.

    Args:
        sumfile: path to a checksum file with '<checksum> <filename>' lines
        base_filename: the file whose checksum to look up; only its
            basename is matched against each line

    Returns:
        the checksum string

    Raises:
        Exception: if no matching line is found
    """
    # hoist the basename out of the loop; it is loop-invariant
    name = os.path.basename(base_filename)
    # bug fix: the file handle was never closed (bare open() in the for
    # loop); use a context manager
    with open(sumfile, 'r') as f:
        for line in f:
            if name in line:
                sum_cksm, _ = line.strip('\n').split()
                return sum_cksm
    raise Exception('could not find checksum in file')
def check_cksm(file, type, sum):
    """Check a checksum of a file"""
    if not os.path.exists(file):
        return False
    # checksum of the file currently on disk
    file_cksm = cksm(file, type)
    # the expected sum may be given directly, or as a path to a sum file
    expected = load_cksm(sum, file) if os.path.isfile(sum) else sum
    logging.debug('file_cksm: %r', file_cksm)
    logging.debug('sum_cksm: %r', expected)
    return file_cksm == expected
def check_md5sum(file, sum):
    """Check an md5sum of a file"""
    return check_cksm(file, type='md5', sum=sum)
def check_sha1sum(file, sum):
    """Check an sha1sum of a file"""
    return check_cksm(file, type='sha1', sum=sum)
def check_sha256sum(file, sum):
    """Check an sha256sum of a file"""
    return check_cksm(file, type='sha256', sum=sum)
def check_sha512sum(file, sum):
    """Check an sha512sum of a file"""
    return check_cksm(file, type='sha512', sum=sum)
# File and Directory Manipulation Functions #
def removedirs(path):
    """Best-effort removal of a file or a directory tree.

    All errors (missing path, permissions, ...) are silently ignored.
    """
    try:
        if not os.path.isdir(path):
            os.remove(path)
            return
        # ignore_errors=True: tree removal is also best-effort
        shutil.rmtree(path, True)
    except Exception:
        pass
def copy(src, dest):
    """Copy a file or directory tree, creating the destination's parent dirs.

    Args:
        src: source path (file or directory)
        dest: destination path

    Raises:
        Exception: if the parent directory cannot be created, or the
            underlying shutil copy fails
    """
    parent_dir = os.path.dirname(dest)
    # bug fix: a bare filename dest gives dirname '' -> os.path.exists('')
    # is False and os.makedirs('') raises; only create a parent when
    # there actually is one
    if parent_dir and not os.path.exists(parent_dir):
        logging.info('attempting to make parent dest dir %s',parent_dir)
        try:
            os.makedirs(parent_dir)
        except Exception:
            logging.error('failed to make dest directory for copy',exc_info=True)
            raise
    if os.path.isdir(src):
        logging.info('dircopy: %s to %s',src,dest)
        # preserve symlinks as links instead of following them
        shutil.copytree(src,dest,symlinks=True)
    else:
        logging.info('filecopy: %s to %s',src,dest)
        shutil.copy2(src,dest)
# Network Functions #
def getInterfaces():
    """Get a list of available interfaces. Requires `psutil`.

    Returns:
        dict of {nic_name: {type: address}}
    """
    interfaces = {}
    for nic_name, snics in psutil.net_if_addrs().items():
        addrs = {}
        for snic in snics:
            # skip entries with an empty address
            if not snic.address:
                continue
            if snic.family == socket.AF_INET:
                addrs['ipv4'] = snic.address
            elif snic.family == socket.AF_INET6:
                addrs['ipv6'] = snic.address
            elif snic.family == psutil.AF_LINK:
                addrs['mac'] = snic.address
        interfaces[nic_name] = addrs
    return interfaces
def get_local_ip_address():
    """Get the local (loopback) ip address"""
    try:
        return socket.gethostbyname('localhost')
    except Exception:
        pass
    # 'localhost' failed to resolve; fall back to this host's fqdn
    return socket.gethostbyname(socket.getfqdn())
def gethostname():
    """Get hostname of this computer."""
    fqdn = socket.getfqdn()
    short = socket.gethostname()
    # return whichever form carries more information (the longer name)
    return short if len(short) > len(fqdn) else fqdn
@contextmanager
def _http_helper(options={}, sync=True):
    """Set up an http session using requests.

    Recognized keys in options: username/password, sslcert, sslkey,
    cacert, token.
    """
    session_cls = Session if sync else AsyncSession
    with session_cls(retries=10, backoff_factor=0.3) as s:
        if 'username' in options and 'password' in options:
            s.auth = (options['username'], options['password'])
        if 'sslcert' in options:
            # cert may be a single combined file or a (cert, key) pair
            if 'sslkey' in options:
                s.cert = (options['sslcert'], options['sslkey'])
            else:
                s.cert = options['sslcert']
        if 'cacert' in options:
            s.verify = options['cacert']
        if 'token' in options:
            s.headers.update({'Authorization': f'bearer {options["token"]}'})
        yield s
async def download(url, local, options={}):
    """Download a file, checksumming if possible.

    Args:
        url: source url (http(s), file:, gsiftp:/ftp:, or a bare existing path)
        local: destination path or file: url; may be an existing directory
        options: dict of http auth/cert/token options for the session

    Returns:
        the final local path of the downloaded file

    Raises:
        Exception: for an unsupported protocol or a failed download
    """
    local = os.path.expanduser(os.path.expandvars(local))
    url = os.path.expanduser(os.path.expandvars(url))
    if not isurl(url):
        # a bare path that exists on disk is treated as a file: url
        if os.path.exists(url):
            url = 'file:'+url
        else:
            raise Exception("unsupported protocol %s" % url)
    # strip off query params
    if '?' in url:
        clean_url = url[:url.find('?')]
    elif '#' in url:
        clean_url = url[:url.find('#')]
    else:
        clean_url = url
    # fix local to be a filename to write to
    if local.startswith('file:'):
        local = local[5:]
    if os.path.isdir(local):
        # destination is a directory: keep the url's basename
        local = os.path.join(local, os.path.basename(clean_url))
    logging.warning('wget(): src: %s, local: %s', url, local)
    # actually download the file
    try:
        if url.startswith('http'):
            logging.info('http from %s to %s', url, local)
            # http_proxy fix: ensure the proxy env value carries a scheme
            for k in os.environ:
                if k.lower() == 'http_proxy' and not os.environ[k].startswith('http'):
                    os.environ[k] = 'http://'+os.environ[k]
            def _d():
                # blocking download; run in an executor below
                with _http_helper(options) as s:
                    r = s.get(url, stream=True, timeout=300)
                    with open(local, 'wb') as f:
                        for chunk in r.iter_content(65536):
                            f.write(chunk)
                    # NOTE(review): raise_for_status() runs after the body is
                    # written, so an error response body is written to `local`
                    # before the exception triggers the cleanup path below
                    r.raise_for_status()
            await asyncio.get_event_loop().run_in_executor(None, _d)
        elif url.startswith('file:'):
            url = url[5:]
            logging.info('copy from %s to %s', url, local)
            if os.path.exists(url):
                await asyncio.get_event_loop().run_in_executor(None, partial(copy, url, local))
        elif url.startswith('gsiftp:') or url.startswith('ftp:'):
            logging.info('gsiftp from %s to %s', url, local)
            await asyncio.get_event_loop().run_in_executor(None, partial(GridFTP.get, url, filename=local))
        else:
            raise Exception("unsupported protocol %s" % url)
        if not os.path.exists(local):
            raise Exception('download failed - file does not exist')
    except Exception:
        # clean up any partial download before re-raising
        await asyncio.get_event_loop().run_in_executor(None, removedirs, local)
        raise
    return local
async def upload(local, url, checksum=True, options={}):
    """Upload a file, checksumming if possible.

    Args:
        local: local source path or file: url; directories are tarred first
        url: destination url (http(s), file:, gsiftp:/ftp:, or a bare /path)
        checksum: if True, verify the upload with a checksum afterwards
        options: dict of http auth/cert/token options for the session

    Raises:
        Exception: for an unsupported protocol, a missing local file,
            or a checksum mismatch
    """
    local = os.path.expandvars(local)
    url = os.path.expandvars(url)
    if not isurl(url):
        # an absolute path is treated as a file: url
        if url.startswith('/'):
            url = 'file:'+url
        else:
            raise Exception("unsupported protocol %s" % url)
    if local.startswith('file:'):
        local = local[5:]
    if os.path.isdir(local):
        # directories are uploaded as a tar archive
        compress(local, 'tar')
        local += '.tar'
    logging.warning('wput(): local: %s, url: %s', local, url)
    if not os.path.exists(local):
        logging.warning('upload: local path, %s, does not exist', local)
        raise Exception('local file does not exist')
    # sha512 computed up-front; used by the verification branches below
    chksum = sha512sum(local)
    chksum_type = 'sha512'
    if not checksum:
        # NOTE(review): {checksum} interpolates the boolean flag here —
        # looks like a leftover; confirm the intended message
        logging.warning(f'not performing checksum {chksum_type}\n{checksum}: {url}')
    # actually upload the file
    if url.startswith('http'):
        logging.info('http from %s to %s', local, url)
        # http_proxy fix: ensure the proxy env value carries a scheme
        for k in os.environ:
            if k.lower() == 'http_proxy' and not os.environ[k].startswith('http'):
                os.environ[k] = 'http://'+os.environ[k]
        def _d():
            # blocking upload; run in an executor below
            with _http_helper(options) as s:
                try:
                    with open(local, 'rb') as f:
                        r = s.put(url, timeout=300, data=f)
                    r.raise_for_status()
                except requests.exceptions.HTTPError as e:
                    if e.response.status_code != 405:
                        raise
                    else:
                        # 405: server rejects PUT; retry as a multipart POST
                        logging.warning('WebDav PUT not allowed, trying multipart upload')
                        with open(local, 'rb') as f:
                            m = MultipartEncoder(
                                fields={'field0': ('filename', f, 'text/plain')}
                            )
                            r = s.post(url, timeout=300, data=m, headers={'Content-Type': m.content_type})
                        r.raise_for_status()
                if checksum:
                    # get checksum
                    if 'ETAG' in r.headers:
                        # assumes the server's ETag is the md5 of the
                        # content — TODO confirm for the servers in use
                        md5 = r.headers['ETAG'].strip('"\'')
                        if md5sum(local) != md5:
                            raise Exception('http checksum error')
                    else:
                        # no ETag: download the file back and compare sha512
                        r = s.get(url, stream=True, timeout=300)
                        try:
                            with open(local+'.tmp', 'wb') as f:
                                for chunk in r.iter_content(65536):
                                    f.write(chunk)
                            r.raise_for_status()
                            if sha512sum(local+'.tmp') != chksum:
                                raise Exception('http checksum error')
                        finally:
                            # always remove the temporary verification copy
                            removedirs(local+'.tmp')
        await asyncio.get_event_loop().run_in_executor(None, _d)
    elif url.startswith('file:'):
        # use copy command
        url = url[5:]
        def _c():
            if os.path.exists(url):
                logging.warning('put: file already exists. overwriting!')
                removedirs(url)
            copy(local, url)
            if checksum and sha512sum(url) != chksum:
                raise Exception('file checksum error')
        await asyncio.get_event_loop().run_in_executor(None, _c)
    elif url.startswith('gsiftp:') or url.startswith('ftp:'):
        def _g():
            try:
                GridFTP.put(url, filename=local)
            except Exception:
                # because d-cache doesn't allow overwriting, try deletion
                GridFTP.delete(url)
                GridFTP.put(url, filename=local)
            if checksum and GridFTP.sha512sum(url) != chksum:
                raise Exception('gridftp checksum error')
        await asyncio.get_event_loop().run_in_executor(None, _g)
    else:
        raise Exception("unsupported protocol %s" % url)
def delete(url, options={}):
    """Delete a url or file"""
    url = os.path.expandvars(url)
    # a bare existing path is treated as a file: url
    if (not isurl(url)) and os.path.exists(url):
        url = 'file:'+url
    if url.startswith('http'):
        logging.info('delete http: %s', url)
        with _http_helper(options) as s:
            r = s.delete(url, timeout=300)
            r.raise_for_status()
        return
    if url.startswith('file:'):
        path = url[5:]
        logging.info('delete file: %r', path)
        if os.path.exists(path):
            removedirs(path)
        return
    if url.startswith('gsiftp:') or url.startswith('ftp:'):
        logging.info('delete gsiftp: %r', url)
        GridFTP.rmtree(url)
        return
    raise Exception("unsupported protocol %s" % url)
def isurl(url):
    """Determine if this is a supported protocol"""
    prefixes = ('file:', 'http:', 'https:', 'ftp:', 'ftps:', 'gsiftp:')
    try:
        # str.startswith accepts a tuple of prefixes
        return url.startswith(prefixes)
    except Exception:
        pass
    # fallback for objects whose startswith rejects a tuple
    try:
        return any(url.startswith(p) for p in prefixes)
    except Exception:
        return False