#!/usr/bin/env python import os import time import sys import csv import glob import tempfile import shutil import platform import argparse import numpy as np import surfa as sf import scipy.ndimage # defer tensorflow import until we need it (for faster command-line parsing) tf = None description = """ Segment subcortical limbic structures. Input images can be provided by one of two methods. To segment one or multiple T1-weighted images, use the --i flag to point to an input image file or directory containing a series of images. The --o flag should specify the corresponding output segmentation file or directory. For example: mri_sclimbic_seg --i image.mgz --o seg.mgz To process a series of freesurfer recon-all subjects, use the --s input flag. When no arguments are provided to this flag, subjects will be searched for in the 'subjects directory' defined by the --sd flag or the SUBJECTS_DIR env variable. Otherwise, a set of subject names can be specified as arguments. For example: mri_sclimbic_seg --s subj1 subj2 subj3 In freesurfer subject-mode, outputs will be saved to the subject's mri and stats subdirectories, and volumetric stats will be computed and saved automatically. """ # ------------------------------------------------------------------------------------------------ # Main Entrypoint # ------------------------------------------------------------------------------------------------ def main(): # configure command-line parser = argparse.ArgumentParser(description=description) # normal-mode options parser.add_argument('-i', '--i', help='T1-w image(s) to segment. Can be a path to a single image or a directory of images.') parser.add_argument('-o', '--o', help='Segmentation output (required if --i is provided). Must be the same type as ' 'the input path (a single file or directory).') # subject-mode options parser.add_argument('-s', '--s', nargs='*', help='Process a series of freesurfer recon-all subjects (enables subject-mode).') parser.add_argument('--sd', help='Set the subjects directory (overrides the SUBJECTS_DIR env variable).') # general options parser.add_argument('--conform', action='store_true', help='Resample input to 1mm-iso; results will be put back in native resolution.') parser.add_argument('--etiv', action='store_true', help='deInclude eTIV in volume stats (enabled by default in subject-mode and --tal).') parser.add_argument('--tal', help='Alternative talairach xfm transform for estimating TIV. Can be file or suffix (for multiple inputs).') parser.add_argument('--write_posteriors', action='store_true', help='Save the label posteriors.') parser.add_argument('--write_volumes', action='store_true', help='Save label volume stats (enabled by default in subject-mode).') parser.add_argument('--write_qa_stats', action='store_true', help='Save QA stats (z and confidence).') parser.add_argument('--exclude', type=int, nargs='+', default=[], help='List of label IDs to exclude in any output stats files.') parser.add_argument('--keep_ac', action='store_true', help='Explicitly keep anterior commissure in the volume/qa files.') parser.add_argument('--vox-count-volumes', action='store_true', help='Use discrete voxel count for label volumes.') parser.add_argument('--model', help='Alternative model weights to load.') parser.add_argument('--ctab', help='Alternative color lookup table to embed in segmentation. Must be minimal, including 0, and sorted.') parser.add_argument('--population-stats', help='Alternative population volume stats for QA output.') parser.add_argument('--debug', action='store_true', help='Enable debug logging.') parser.add_argument('--vmp', action='store_true', help='Enable printing of vmpeak at the end.') parser.add_argument('--threads', type=int, default=1, help='Number of threads to use. Default is 1.') parser.add_argument('--7T', dest='sevenT', action='store_true', help='Preprocess 7T images (just sets percentile to 99.9).') parser.add_argument('--percentile', type=float, help='Use intensity percentile threshold for normalization.') parser.add_argument('--cuda-device', help='Cuda device for GPU support.') parser.add_argument('--output-base', default='sclimbic',help='String to use in output file name; default is sclimbic') # Ideally, we would get this from the model parser.add_argument('--nchannels', type=int, default=1,help='Number of channels') # check for no arguments if len(sys.argv) < 2: parser.print_help() sys.exit(1) # print out the command line print(' '.join(sys.argv)) # parse commandline args = parser.parse_args() # a few sanity checks on the command-line inputs if args.i is None and args.s is None: sf.system.fatal('Input image(s) or subject(s) to segment must be provided with the --i or --s flags.') if args.i is not None and args.s is not None: sf.system.fatal('Cannot provide both input image (--i) and subject (--s) flags. Choose one input mode.') if args.i is not None and args.o is None: sf.system.fatal('--o output flag must be provided if --i input is used.') # Automatically exclude AntCom unless explicitly kept if not args.keep_ac: # not explicity being kept if not (853 in args.exclude): # not already in the list args.exclude.append(853) # add it to the list if not (853 in args.exclude): print('Keeping anterior commissure in vols and stats') if len(args.exclude) > 0: print("Excluding seg", args.exclude) if args.tal is not None: args.etiv = 1 # check for fs home if not os.environ.get('FREESURFER_HOME'): sf.system.fatal('FREESURFER_HOME is not set. Please source FreeSurfer.') # configure cuda device if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device cuda_device = os.getenv('CUDA_VISIBLE_DEVICES') if cuda_device is None or cuda_device == '-1': os.environ['CUDA_VISIBLE_DEVICES'] = '-1' print('Using CPU') else: print('Using GPU device', cuda_device) # defer tensorflow importing until after parsing os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if args.debug else '3' global tf import tensorflow as tf if not args.debug: tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # set number of threads print('Using %d thread(s)' % args.threads) tf.config.threading.set_inter_op_parallelism_threads(args.threads) tf.config.threading.set_intra_op_parallelism_threads(args.threads) percentile = args.percentile if args.sevenT: print('7T image flag provided, using 99.9 percentile normalization') percentile = 99.9 # load lookup table for segmentation lut_file = args.ctab if args.ctab else os.path.join(os.environ.get('FREESURFER_HOME'), 'models', 'sclimbic.ctab') labels = sf.load_label_lookup(lut_file) print('Loaded lookup table', lut_file) # load population stats for QA purposes if args.population_stats: pop_stats_file = args.population_stats else: pop_stats_file = os.path.join(os.environ.get('FREESURFER_HOME'), 'models', 'sclimbic.volstats.csv') with open(pop_stats_file, 'r') as csvfile: population_stats = {row.pop('VolStat_mm3'): row for row in csv.DictReader(csvfile)} print('Loaded population stats', pop_stats_file) # load model weights and initialize segmenter model_file = args.model if args.model else os.path.join(os.environ.get('FREESURFER_HOME'), 'models', 'sclimbic.fsm+ad.t1.nstd00-50.nstd32-50.h5') segmenter = LimbicSegmenter(model_file=model_file, labels=labels, population_stats=population_stats, conform=args.conform, store_etiv=args.etiv, store_qa_stats=args.write_qa_stats, debug=args.debug, volumes_from_vox_count=args.vox_count_volumes, percentile=percentile, exclude=args.exclude, nchannels=args.nchannels); print('Loaded model weights', model_file) # loop through each image and segment if args.s is not None: # in freesurfer subject-mode, we can grab the subject's nu.mgz # and use it as input sd = os.getenv('SUBJECTS_DIR') if args.sd is None else args.sd if sd is None: sf.system.fatal('Must set subjects directory with --sd or SUBJECTS_DIR env variable.') summary_file_prefix = os.path.join(sd, args.output_base + '_') print('Using subject directory', sd) # if no subjects have been provided, let's search the subjects directory subjects = args.s if len(subjects) == 0: nus = glob.glob(f'{sd}/*/mri/nu.mgz') subjects = [os.path.basename(n.replace('/mri/nu.mgz', '')) for n in nus] # we still haven't found anything if len(subjects) == 0: sf.system.fatal(f'Subjects directory {sd} does not contain any valid recon-all subjects.') # loop through subjects, set filenames, and segment for n, subj in enumerate(subjects): # sanity check the subject subjdir = os.path.join(sd, subj) if not os.path.isdir(subjdir): sf.system.fatal(f'Recon-all subject {subj} does not exist in {sd}.') # set default IO parameters params = { 'input_file': os.path.join(subjdir, 'mri', 'nu.mgz'), 'segmentation_path': os.path.join(subjdir, 'mri', args.output_base+'.mgz'), 'volumes_path': os.path.join(subjdir, 'stats', args.output_base+'.stats'), 'case_name': subj, } # estimate TIV from talairach lta if args.tal: lta = os.path.join(subjdir, args.tal) else: lta = os.path.join(subjdir, 'mri', 'transforms', 'talairach.xfm.lta') params['etiv'] = compute_etiv_from_lta(lta) print('Computed eTIV from talairach') # save posterior data to the mri subdir if args.write_posteriors: params['posteriors_path'] = os.path.join(subjdir, 'mri', args.output_base+'.posteriors.mgz') # save QA output to the stats subdir if args.write_qa_stats: params['qa_stats_path'] = os.path.join(subjdir, 'stats', args.output_base+'.qa.stats') print('\nSegmenting subject %s %d/%d' % (subj, n + 1, len(subjects))) segmenter.process_files(**params) segmenter.write_all_case_volumes(summary_file_prefix + 'volumes_all.csv') else: # normal file/directory input mode # available file formats exts = ('.mgh', '.mgz', '.nii', '.nii.gz') isdir = os.path.isdir(args.i) if isdir: # find valid images in input directory input_files = [] for ext in exts: input_files += sorted(glob.glob(os.path.join(args.i, '*' + ext))) if len(input_files) == 0: sf.system.fatal(f'Could not find any valid input images in {args.i}.') os.makedirs(args.o, exist_ok=True) logfile = os.path.join(args.o,'mri_sclimbic.log'); else: if not args.i.endswith(exts): sf.system.fatal(f'{args.i} is an unsupported image file type.') input_files = [args.i] dirname = os.path.dirname(args.o) if dirname: os.makedirs(dirname, exist_ok=True) logfile = os.path.join(dirname,'mri_sclimbic.log'); # Write the command line. Prob not the best way to do it. with open(logfile, 'w') as file: file.write('modelfile ' + model_file + '\n'); file.write('ctab ' + lut_file + '\n'); file.write('cd ' + os.getcwd() + '\n'); file.write(' '.join(sys.argv) + '\n'); # quick utility to add to filenames while keeping extension def split_extension(filename): for ext in exts: if filename.endswith(ext): return (filename[:-len(ext)], ext) sf.system.fatal(f'{filename} is an unsupported image file type.') # loop through the images and segment for n, input_file in enumerate(input_files): params = {'input_file': input_file} # some logic to determine output filename if isdir: true_basename, ext = split_extension(os.path.basename(input_file)) params['case_name'] = true_basename basename = os.path.join(args.o, true_basename + '.'+args.output_base) params['segmentation_path'] = f'{basename}{ext}' else: basename, ext = split_extension(args.o) params['case_name'] = basename params['segmentation_path'] = args.o # optional outputs if args.write_posteriors: params['posteriors_path'] = f'{basename}.posteriors{ext}' if args.write_volumes: params['volumes_path'] = f'{basename}.stats' if args.write_qa_stats: params['qa_stats_path'] = f'{basename}.qa.stats' # optional TIV estimation (really only applies to volume stats) if args.write_volumes and args.etiv: if args.tal: if isdir: xfm = os.path.join(args.i, true_basename + args.tal) else: xfm = args.tal params['etiv'] = compute_etiv_from_lta(xfm) print('Computed eTIV from talairach file') else: print("Computing etiv from scratch") params['etiv'] = compute_etiv_from_scratch(input_file) print('\nSegmenting image %d/%d' % (n + 1, len(input_files))) segmenter.process_files(**params) if isdir: summary_file_prefix = os.path.join(args.o, args.output_base+ '_') else: summary_file_prefix = f'{basename}.' # write all case volumes in output directory if os.path.isdir(args.o) and len(input_files) > 0: segmenter.write_all_case_volumes(summary_file_prefix + 'volumes_all.csv') # write qa stats if args.write_qa_stats: segmenter.write_all_case_zscores(summary_file_prefix + 'zqa_scores_all.csv') segmenter.write_all_case_confidences(summary_file_prefix + 'confidences_all.csv') # check memory usage if args.debug or args.vmp: print_vm_peak() # all done print('\nIf you use this tool in a publication, please cite:') print('A Deep Learning Toolbox for Automatic Segmentation of Subcortical Limbic Structures from MRI Images'); print('Greve, DN, Billot, B, Cordero, D, Hoopes, M. Hoffmann, A, Dalca, A, Fischl, B, Iglesias, JE, Augustinack, JC') print('Submitted') # ------------------------------------------------------------------------------------------------ # LimbicSegmenter # ------------------------------------------------------------------------------------------------ class LimbicSegmenter: """ Isolated class to handle image IO, preprocessing, prediction, and postprocessing """ def __init__(self, model_file, labels, population_stats, conform=True, store_etiv=False, store_qa_stats=False, debug=False, volumes_from_vox_count=False, percentile=None, exclude=[], nchannels=1): self.labels = labels self.population_stats = population_stats self.conform = conform self.inshape = (160, 160, 160) self.case_volumes = {} self.case_etivs = {} self.case_prob_means = {} self.store_etiv = store_etiv self.store_qa_stats = store_qa_stats self.last_time = time.time() self.debug = debug self.volumes_from_vox_count = volumes_from_vox_count self.percentile = percentile self.nchannels = nchannels; # build mask of labels to exclude in stats output files # always ignore unknown label self.exclude = [0] + exclude self.exclude_mask = [sid not in self.exclude for sid in labels.keys()] self.label_names = [label.name for i, label in zip(self.exclude_mask, self.labels.values()) if i] print(f'nb_labels {len(self.labels)}'); # build and load model self.model = unet(nb_features=24, input_shape=(*self.inshape, self.nchannels), nb_levels=3, conv_size=3, nb_labels=len(self.labels), name='unet', prefix=None, feat_mult=2, pool_size=2, padding='same', dilation_rate_mult=1, activation='elu', use_residuals=False, final_pred_activation='softmax', nb_conv_per_level=2, layer_nb_feats=None, conv_dropout=0, batch_norm=-1, input_model=None) self.model.load_weights(model_file, by_name=True) def reset_timer(self): """ Reset internal timer. """ self.last_time = time.time() def print_time(self, message): """ Print timer time if debugging is enabled. """ if self.debug: print('%s: %.4f s' % (message, time.time() - self.last_time)) def write_all_case_volumes(self, path): """ Write all case volumes to a csv. """ header = ['case'] + self.label_names if self.store_etiv: header.append('eTIV') with open(path, 'w') as file: file.write(','.join(header) + '\n') for case, volumes in self.case_volumes.items(): volumes = volumes[self.exclude_mask] if self.store_etiv: volumes = np.append(volumes, self.case_etivs[case]) file.write(','.join([case] + ['%.4f' % v for v in volumes])) file.write('\n') print('\nWrote summary of label volumes to', path) def write_all_case_zscores(self, path): """ Write all case z-scores to a csv. """ stats = [self.population_stats.get(label) for label in self.label_names] stat_mask = [idx for idx, stat in enumerate(stats) if stat is not None] labels = [label for label, stat in zip(self.label_names, stats) if stat is not None] mean = [float(stat['mean']) for stat in stats if stat is not None] std = [float(stat['std']) for stat in stats if stat is not None] with open(path, 'w') as file: file.write(','.join(['case'] + labels) + '\n') for case, volumes in self.case_volumes.items(): vol = volumes[self.exclude_mask][stat_mask] zscores = (vol - mean) / std file.write(','.join([case] + ['%.4f' % z for z in zscores]) + '\n') print('Wrote summary of label z-scores to', path) def write_all_case_confidences(self, path): """ Write all case confidences (mean prediction prob) to a csv. """ with open(path, 'w') as file: file.write(','.join(['case'] + self.label_names) + '\n') for case, prob_means in self.case_prob_means.items(): file.write(','.join([case] + ['%.4f' % v for v in prob_means[self.exclude_mask]]) + '\n') print('Wrote summary of label prediction confidences to', path) def preprocess(self, image): """ Preprocess an image by conforming it to the correct orientation, shape, and scale. """ # check resolution if not np.allclose(image.geom.voxsize, (1, 1, 1), rtol=0, atol=1e-2): image_geom_voxsize = [f'{x:4.2f}' for x in image.geom.voxsize] if self.conform: print(f'The input image has resolution {image_geom_voxsize} mm, but 1mm-isotropic input is required.\n' 'However, --conform has been specified, so the volume will be resliced to 1mm iso.\n') else: print('') sf.system.fatal(f'The input image has resolution {image_geom_voxsize}, but 1mm-isotropic input is required.\n' 'The volume can be resliced to 1mm-iso by specifying --conform (results may suffer).\n') # check channels if image.nframes != self.nchannels: sf.system.fatal(f'Input image has {image.nframes}, expecting {self.nchannels}.') # normalize image data dmin = image.min() dmax = image.max() if self.percentile is None else image.percentile(self.percentile, nonzero=True) if dmin == dmax: sf.system.fatal('Input image is blank!') image = (image.astype(np.float32) - dmin) / (dmax - dmin) image = image.clip(0, 1) # conform to RAS 1mm space processed = image.conform(shape=(*self.inshape, self.nchannels), voxsize=1.0, orientation='RAS', dtype='float32', copy=False) return processed def segment(self, image): """ Segment a raw input image. """ self.reset_timer() conformed = self.preprocess(image) self.print_time('Preprocess time') # posterior prediction self.reset_timer() prediction = self.model.predict(conformed.framed_data[np.newaxis]).squeeze() self.print_time('Prediction time') self.reset_timer() # let's clean up the posteriors a bit, but we'll do this in # a minimal cropped space to speed things up seg = conformed.new(prediction.argmax(-1)) bbox = seg.bbox(margin=2) cropped_seg = seg[bbox] posteriors = conformed.new(prediction)[bbox] # mask the posteriors around each label dilate_struct = build_binary_structure(1, 3) for label in range(1, len(self.labels)): cropped_pred_label = posteriors.data[..., label] label_mask = scipy.ndimage.binary_dilation(cropped_seg.data == label, dilate_struct) cropped_pred_label[np.logical_not(label_mask)] = 0 posteriors[..., label] = cropped_pred_label # ensure that the posteriors sum to 1 in the cropped space posteriors[..., 0] = 1.0 - np.sum(posteriors[..., 1:], axis=-1) posteriors = posteriors.clip(0, 1) posteriors /= np.sum(posteriors, axis=-1, keepdims=True) # resample cropped posteriors to original resolution posteriors = posteriors.resize(image.geom.voxsize, copy=False) # compute the final hard segmentation and compute voxel counts while we're at it vox_counts = [] mean_probs = [] argmax = posteriors.data.argmax(axis=-1) segmap = np.zeros(argmax.shape, dtype='int32') for n, nid in enumerate(self.labels.keys()): label_mask = argmax == n segmap[label_mask] = nid vox_counts.append(np.count_nonzero(label_mask)) if self.store_qa_stats: probs = posteriors.data[..., n][label_mask] mean_probs.append(probs.mean() if len(probs) > 0 else 0.0) vox_counts = np.array(vox_counts) mean_probs = np.array(mean_probs) # compute label volumes in original resolution voxvol = np.prod(posteriors.geom.voxsize) if self.volumes_from_vox_count: volumes = voxvol * np.array(vox_counts) else: volumes = voxvol * posteriors.data.reshape(-1, posteriors.shape[-1]).sum(0) # resample final hard segmentation to original space segmentation = posteriors.new(segmap).resample_like(image, method='nearest') segmentation.labels = self.labels self.print_time('Postprocess time') return (posteriors, segmentation, vox_counts, volumes, mean_probs) def process_files( self, input_file, segmentation_path, posteriors_path=None, volumes_path=None, qa_stats_path=None, etiv=None, case_name=None): # load image if not os.path.isfile(input_file): sf.system.fatal(f'Input image {input_file} does not exist') image = sf.load_volume(input_file) print('Loaded input image from', input_file) # segment post, seg, vox_counts, volumes, mean_probs = self.segment(image) # write segmentation seg.save(segmentation_path) print('Wrote segmentation to', segmentation_path) # write posteriors if posteriors_path is not None: post.save(posteriors_path) print('Wrote posteriors to', posteriors_path) # write volume stats in FS format if volumes_path is not None: with open(volumes_path, 'w') as file: file.write('# Subcortical Limbic Volumetric Stats\n') file.write('# Created by mri_sclimbic_seg\n') if etiv is not None: file.write('# Measure EstimatedTotalIntraCranialVol, eTIV, Estimated ' + \ f'Total Intracranial Volume, {etiv:.6f}, mm^3\n') label_matches = [(vid, nid) for (vid, nid) in enumerate(self.labels.keys()) if nid not in self.exclude] file.write(f'# NRows {len(label_matches)}\n') file.write('# NTableCols 5\n') file.write('# ColHeaders Index SegId NVoxels Volume_mm3 StructName\n') for n, (vid, nid) in enumerate(label_matches): file.write(f'{n+1: <4} {nid: >6}{vox_counts[vid]: >6}{volumes[vid]: >12.4f} {self.labels[nid].name}\n') print('Wrote volume stats to', volumes_path) # store label volumes self.case_volumes[case_name] = volumes if self.store_etiv: self.case_etivs[case_name] = etiv # write mean probs if self.store_qa_stats: self.case_prob_means[case_name] = mean_probs # ------------------------------------------------------------------------------------------------ # Utilities # ------------------------------------------------------------------------------------------------ def compute_etiv_from_lta(lta): """ Compute eTIV by loading the image or subject's talairach lta. """ scale_factor = 1948.106 etiv = 1e3 * scale_factor / sf.load_affine(lta).det() return etiv def compute_etiv_from_scratch(image): """ Compute eTIV by conforming, normalizing, and registering the input image to talairach space. This will slow down processing substantially. """ # make a temporary directory for the intermediate outputs tmpdir = tempfile.mkdtemp() norm = os.path.join(tmpdir, 'nu.mgz') xfm = os.path.join(tmpdir, 'talairach.xfm') lta = os.path.join(tmpdir, 'talairach.xfm.lta') # conform the input image ret = sf.system.run(f'mri_convert --conform {image} {norm}') if ret != 0: sf.system.fatal('mri_convert --conform failed!') # run intensity normalization ret = sf.system.run(f'mri_nu_correct.mni --no-rescale --i {norm} --o {norm} --n 1 --proto-iters 1000 --distance 50 --ants-n4') if ret != 0: sf.system.fatal('mri_nu_correct failed!') # run talairach registration ret = sf.system.run(f'talairach_avi --i {norm} --xfm {xfm}') if ret != 0: sf.system.fatal('talairach_avi failed!') # convert XFM to LTA mni305 = os.path.join(os.environ.get('FREESURFER_HOME'), 'average', 'mni305.cor.mgz') ret = sf.system.run(f'lta_convert --src {norm} --trg {mni305} --inxfm {xfm} --outlta {lta} --subject fsaverage --ltavox2vox') if ret != 0: sf.system.fatal('lta_convert failed!') # estimate TIV etiv = compute_etiv_from_lta(lta) shutil.rmtree(tmpdir) return etiv def print_vm_peak(): """ Print the VM peak of the running process. This is only available on linux platforms. """ if platform.system() != 'Linux': return procstat = os.path.join('/proc', str(os.getpid()), 'status') fp = open(procstat, 'r') lines = fp.readlines() for line in lines: strs = line.split() if(len(strs) < 3): continue if(strs[0] != 'VmPeak:'): continue print('vmpcma:', int(strs[1])) def build_binary_structure(connectivity, n_dims): """ Return a dilation element with provided connectivity. """ shape = [connectivity * 2 + 1] * n_dims dist = np.ones(shape) center = tuple([tuple([int(s / 2)]) for s in shape]) dist[center] = 0 dist = scipy.ndimage.distance_transform_edt(dist) struct = (dist <= connectivity) * 1 return struct # ------------------------------------------------------------------------------------------------ # Neurite Utilities - See github.com/adalca/neurite # ------------------------------------------------------------------------------------------------ def unet(nb_features, input_shape, nb_levels, conv_size, nb_labels, name='unet', prefix=None, feat_mult=1, pool_size=2, padding='same', dilation_rate_mult=1, activation='elu', use_residuals=False, final_pred_activation='softmax', nb_conv_per_level=1, layer_nb_feats=None, conv_dropout=0, batch_norm=None, input_model=None): """ Unet-style tf.keras model with an overdose of parametrization. Copied with permission from github.com/adalca/neurite. """ # naming model_name = name if prefix is None: prefix = model_name # volume size data ndims = len(input_shape) - 1 if isinstance(pool_size, int): pool_size = (pool_size,) * ndims # get encoding model enc_model = conv_enc(nb_features, input_shape, nb_levels, conv_size, name=model_name, prefix=prefix, feat_mult=feat_mult, pool_size=pool_size, padding=padding, dilation_rate_mult=dilation_rate_mult, activation=activation, use_residuals=use_residuals, nb_conv_per_level=nb_conv_per_level, layer_nb_feats=layer_nb_feats, conv_dropout=conv_dropout, batch_norm=batch_norm, input_model=input_model) # get decoder # use_skip_connections=True makes it a u-net lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None dec_model = conv_dec(nb_features, None, nb_levels, conv_size, nb_labels, name=model_name, prefix=prefix, feat_mult=feat_mult, pool_size=pool_size, use_skip_connections=True, padding=padding, dilation_rate_mult=dilation_rate_mult, activation=activation, use_residuals=use_residuals, final_pred_activation=final_pred_activation, nb_conv_per_level=nb_conv_per_level, batch_norm=batch_norm, layer_nb_feats=lnf, conv_dropout=conv_dropout, input_model=enc_model) final_model = dec_model return final_model def conv_enc(nb_features, input_shape, nb_levels, conv_size, name=None, prefix=None, feat_mult=1, pool_size=2, dilation_rate_mult=1, padding='same', activation='elu', layer_nb_feats=None, use_residuals=False, nb_conv_per_level=2, conv_dropout=0, batch_norm=None, input_model=None): """ Fully Convolutional Encoder. Copied with permission from github.com/adalca/neurite. """ # naming model_name = name if prefix is None: prefix = model_name # first layer: input name = '%s_input' % prefix if input_model is None: input_tensor = tf.keras.layers.Input(shape=input_shape, name=name) last_tensor = input_tensor else: input_tensor = input_model.inputs last_tensor = input_model.outputs if isinstance(last_tensor, list): last_tensor = last_tensor[0] last_tensor = tf.keras.layers.Reshape(input_shape)(last_tensor) input_shape = last_tensor.shape.as_list()[1:] # volume size data ndims = len(input_shape) - 1 input_shape = tuple(input_shape) if isinstance(pool_size, int): pool_size = (pool_size,) * ndims # prepare layers convL = getattr(tf.keras.layers, 'Conv%dD' % ndims) conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'} maxpool = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims) # down arm: # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers lfidx = 0 # level feature index for level in range(nb_levels): lvl_first_tensor = last_tensor nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int) conv_kwargs['dilation_rate'] = dilation_rate_mult ** level for conv in range(nb_conv_per_level): # does several conv per level, max pooling applied at the end if layer_nb_feats is not None: # None or List of all the feature numbers nb_lvl_feats = layer_nb_feats[lfidx] lfidx += 1 name = '%s_conv_downarm_%d_%d' % (prefix, level, conv) if conv < (nb_conv_per_level - 1) or (not use_residuals): last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) else: # no activation last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) if conv_dropout > 0: # conv dropout along feature space only name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv) noise_shape = [None, *[1] * ndims, nb_lvl_feats] last_tensor = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) if use_residuals: convarm_layer = last_tensor # the "add" layer is the original input # However, it may not have the right number of features to be added nb_feats_in = lvl_first_tensor.get_shape()[-1] nb_feats_out = convarm_layer.get_shape()[-1] add_layer = lvl_first_tensor if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): name = '%s_expand_down_merge_%d' % (prefix, level) last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor) add_layer = last_tensor if conv_dropout > 0: name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv) noise_shape = [None, *[1] * ndims, nb_lvl_feats] last_tensor = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) name = '%s_res_down_merge_%d' % (prefix, level) last_tensor = tf.keras.layers.add([add_layer, convarm_layer], name=name) name = '%s_res_down_merge_act_%d' % (prefix, level) last_tensor = tf.keras.layers.Activation(activation, name=name)(last_tensor) if batch_norm is not None: name = '%s_bn_down_%d' % (prefix, level) last_tensor = tf.keras.layers.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # max pool if we're not at the last level if level < (nb_levels - 1): name = '%s_maxpool_%d' % (prefix, level) last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) # create the model and return model = tf.keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) return model def conv_dec(nb_features, input_shape, nb_levels, conv_size, nb_labels, name=None, prefix=None, feat_mult=1, pool_size=2, use_skip_connections=False, padding='same', dilation_rate_mult=1, activation='elu', use_residuals=False, final_pred_activation='softmax', nb_conv_per_level=2, layer_nb_feats=None, batch_norm=None, conv_dropout=0, input_model=None): """ Fully Convolutional Decoder. Copied with permission from github.com/adalca/neurite. Parameters: ... use_skip_connections (bool): if true, turns an Enc-Dec to a U-Net. If true, input_tensor and tensors are required. It assumes a particular naming of layers. conv_enc... """ # naming model_name = name if prefix is None: prefix = model_name # if using skip connections, make sure need to use them. if use_skip_connections: assert input_model is not None, "is using skip connections, tensors dictionary is required" # first layer: input input_name = '%s_input' % prefix if input_model is None: input_tensor = tf.keras.layers.Input(shape=input_shape, name=input_name) last_tensor = input_tensor else: input_tensor = input_model.input last_tensor = input_model.output input_shape = last_tensor.shape.as_list()[1:] # vol size info ndims = len(input_shape) - 1 input_shape = tuple(input_shape) if isinstance(pool_size, int): if ndims > 1: pool_size = (pool_size,) * ndims # prepare layers convL = getattr(tf.keras.layers, 'Conv%dD' % ndims) conv_kwargs = {'padding': padding, 'activation': activation} upsample = getattr(tf.keras.layers, 'UpSampling%dD' % ndims) # up arm: # nb_levels - 1 layers of Deconvolution3D # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu lfidx = 0 for level in range(nb_levels - 1): nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int) conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level) # upsample matching the max pooling layers size name = '%s_up_%d' % (prefix, nb_levels + level) last_tensor = upsample(size=pool_size, name=name)(last_tensor) up_tensor = last_tensor # merge layers combining previous layer if use_skip_connections: conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1) cat_tensor = input_model.get_layer(conv_name).output name = '%s_merge_%d' % (prefix, nb_levels + level) last_tensor = tf.keras.layers.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name) # convolution layers for conv in range(nb_conv_per_level): if layer_nb_feats is not None: nb_lvl_feats = layer_nb_feats[lfidx] lfidx += 1 name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv) if conv < (nb_conv_per_level - 1) or (not use_residuals): last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) else: last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) if conv_dropout > 0: name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv) noise_shape = [None, *[1] * ndims, nb_lvl_feats] last_tensor = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) # residual block if use_residuals: # the "add" layer is the original input # However, it may not have the right number of features to be added add_layer = up_tensor nb_feats_in = add_layer.get_shape()[-1] nb_feats_out = last_tensor.get_shape()[-1] if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): name = '%s_expand_up_merge_%d' % (prefix, level) add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer) if conv_dropout > 0: name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv) noise_shape = [None, *[1] * ndims, nb_lvl_feats] last_tensor = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) name = '%s_res_up_merge_%d' % (prefix, level) last_tensor = tf.keras.layers.add([last_tensor, add_layer], name=name) name = '%s_res_up_merge_act_%d' % (prefix, level) last_tensor = tf.keras.layers.Activation(activation, name=name)(last_tensor) if batch_norm is not None: name = '%s_bn_up_%d' % (prefix, level) last_tensor = tf.keras.layers.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # Compute likelyhood prediction (no activation yet) name = '%s_likelihood' % prefix last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) like_tensor = last_tensor # output prediction layer # we use a softmax to compute P(L_x|I) where x is each location if final_pred_activation == 'softmax': # print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) name = '%s_prediction' % prefix softmax_lambda_fcn = lambda x: tf.keras.activations.softmax(x, axis=ndims + 1) pred_tensor = tf.keras.layers.Lambda(softmax_lambda_fcn, name=name)(last_tensor) # otherwise create a layer that does nothing. else: name = '%s_prediction' % prefix pred_tensor = tf.keras.layers.Activation('linear', name=name)(like_tensor) # create the model and retun model = tf.keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) return model if __name__ == '__main__': main()