% AX_PARTIAL  forward projector that allows to split the full volume into smaller pieces
%     composed tomography projector based on ASTRA toolbox 
%     can be used either for data in RAM or on GPU (automatically decided from class of volData)
%     * volume is split based on "split" parameter, 1 == no splitting
%     * Ax_partial tries to split data if GPU limits are exceeded (ie texture memory limits)
%
% projData = Ax_partial(volData, cfg, vectors,split, varargin)
%
% Inputs:
%   **volData       - array Nx x Ny x Nz of projected volume 
%   **cfg           - config structure generated by ASTRA_initialize
%   **vectors       - orientation of projections generated by ASTRA_initialize
%   **split         - 3 or 4 elements vector, [split X, split Y, split Z, split angle ]
% *optional*
%   **deformation_fields      - 3x1 (or 6x1) cell containing 3D arrays of local deformation of the object 
%   **GPU              - GPU id to be used for reconstruction
%   **verbose          - verbose = 0 : (default) quiet, verbose = 1: standard info , verbose = 2: debug 
%   **keep_on_GPU      - if true, keep the projected data on GPU to make it faster, default == false (the safe option)
%
% *returns*
%   ++projData         - projection of volData 
%
% recompile commands
%  (Linux, GCC 4.8.5)   mexcuda -outdir private  ASTRA_GPU_wrapper/ASTRA_GPU_wrapper.cu ASTRA_GPU_wrapper/util3d.cu ASTRA_GPU_wrapper/par3d_fp.cu ASTRA_GPU_wrapper/par3d_bp.cu
%  (Windows)  mexcuda -outdir private  ASTRA_GPU_wrapper\ASTRA_GPU_wrapper.cu ASTRA_GPU_wrapper\util3d.cu ASTRA_GPU_wrapper\par3d_fp.cu ASTRA_GPU_wrapper\par3d_bp.cu

%*-----------------------------------------------------------------------*
%|                                                                       |
%|  Except where otherwise noted, this work is licensed under a          |
%|  Creative Commons Attribution-NonCommercial-ShareAlike 4.0            |
%|  International (CC BY-NC-SA 4.0) license.                             |
%|                                                                       |
%|  Copyright (c) 2017 by Paul Scherrer Institute (http://www.psi.ch)    |
%|                                                                       |
%|       Author: CXS group, PSI                                          |
%*-----------------------------------------------------------------------*
% You may use this code with the following provisions:
%
% If the code is fully or partially redistributed, or rewritten in another
%   computing language this notice should be included in the redistribution.
%
% If this code, or subfunctions or parts of it, is used for research in a 
%   publication or if it is fully or partially rewritten for another 
%   computing language the authors and institution should be acknowledged 
%   in written form in the publication: “Data processing was carried out 
%   using the “cSAXS software package” developed by the CXS group,
%   Paul Scherrer Institut, Switzerland.” 
%   Variations on the latter text can be incorporated upon discussion with 
%   the CXS group if needed to more specifically reflect the use of the package 
%   for the published work.
%
% A publication that focuses on describing features, or parameters, that
%    are already existing in the code should be first discussed with the
%    authors.
%   
% This code and subroutines are part of a continuous development, they 
%    are provided “as they are” without guarantees or liability on part
%    of PSI or the authors. It is the user responsibility to ensure its 
%    proper use and the correctness of the results.


function projData = Ax_partial(volData, cfg, vectors,split, varargin)
    % Forward projection of volData through astra.ASTRA_GPU_wrapper('fp', ...).
    % The volume is optionally cut into split(1) x split(2) x split(3) subvolumes
    % and the projection angles into Nproj_groups groups, so that each wrapper
    % call stays within the limits checked below (1024 angles per call, 1024e6
    % bytes per subvolume, int32 element count, available GPU memory).
    %
    % Inputs:
    %   volData  - real single array (host or gpuArray) of size
    %              [cfg.iVolX, cfg.iVolY] or [cfg.iVolX, cfg.iVolY, cfg.iVolZ]
    %   cfg      - configuration structure; all fields are converted to double
    %   vectors  - Nangles x 12 matrix of projection geometry vectors
    %   split    - scalar, 3-element or 4-element vector
    %              [split X, split Y, split Z, split angle]
    %   varargin - optional: deformation_fields, GPU, verbose, keep_on_GPU
    % Output:
    %   projData - projections merged along the 3rd (angular) axis; returned as
    %              gpuArray only if volData was a gpuArray or keep_on_GPU is set


    import utils.*
    import math.*
    
    par = inputParser;
    par.KeepUnmatched = true; 

    par.addOptional('deformation_fields', {})  % deformation_fields: 3x1 (or 6x1) cell containing 3D arrays of local deformation of the object 
    par.addOptional('GPU', [])   % GPUs id to be used in reconstruction
    par.addOptional('verbose', 0)   % verbose = 0 : quiet, verbose = 1 : standard info , verbose = 2: debug 
    par.addOptional('keep_on_GPU', false)   % true - keep projected data on GPU to make it faster 

    par.parse(varargin{:})
    r = par.Results;

    % normalize any empty value (e.g. []) to an empty cell so that {:} expansion works
    if isempty(r.deformation_fields); r.deformation_fields = {};  end 
    
    %% check the inputs + check memory availability on GPU 
    assert(gpuDeviceCount>0, 'No CUDA enabled GPU availible')

    if ~(  (isa(volData, 'gpuArray') && strcmp(classUnderlying(volData), 'single')) || ...
            isa(volData,  'single') ) || ~isreal(volData)
        error('Only single precision real input array supported')
    end
    if ~isempty(r.deformation_fields) 
        assert(any(numel(r.deformation_fields) == [3,6]), 'Deformation field expected as 3x1 or 6x1 cell array')
        for i = 1:numel( r.deformation_fields)
            if ~( (isa(r.deformation_fields{i}, 'gpuArray') && strcmp(classUnderlying(r.deformation_fields{i}), 'single')) || ...
            isa(r.deformation_fields{i},  'single'))
                 error('Only single precision for deformation fields is supported')
            end
            r.deformation_fields{i} = gpuArray(r.deformation_fields{i}); % move on GPU, they are usually small 
        end
        r.deformation_fields = r.deformation_fields' ; % transpose so that array(:) results in sorted field 
    end
    

    gpu  = gpuDevice();
    if ~isempty(r.GPU) && gpu.Index ~= r.GPU(1)
        % switch and !! reset !! GPU 
        if isa(volData, 'gpuArray'), error('Switching GPUs will reset content'); end 
        gpu  = gpuDevice(r.GPU(1));
    end
   

    % split factors are forced to integers >= 1
    split = ceil(max(1,split));
    if ismatrix(volData)
        % 2D volume: keep X / Y split (scalar split is reused for both), never split along Z
        split = [split([1, min(2,end)]),1];
        assert(all(size(volData)==[cfg.iVolX,cfg.iVolY]), 'Wrong inputs size')
    else
        assert(all(size(volData)==[cfg.iVolX,cfg.iVolY,cfg.iVolZ]), 'Wrong inputs size')
    end
    
    % output stays on GPU if the input was already there or if explicitly requested
    keep_on_GPU =  isa(volData, 'gpuArray') || r.keep_on_GPU;
            
    %% forward projector that allows to split the full volume into smaller pieces
    assert(all(size(vectors,2)==12), 'Wrong "vectors" size')
    assert(~isempty(vectors), 'Empty input "vectors"')
    assert(cfg.iVolX*cfg.iVolY*cfg.iVolZ > 0, 'Inputs volume is empty'); 
    assert(cfg.iProjU*cfg.iProjV*cfg.iProjAngles > 0, 'Projections are empty'); 

    % be sure that the ASTRA wrapper is fed by doubles !!
    for i = fieldnames(cfg)'
        cfg.(i{1}) = double(cfg.(i{1})); 
    end
    vectors = double(vectors); 
    split=double(split);

    
    % the number of angles is always taken from "vectors", not from the incoming cfg
    cfg.iProjAngles = size(vectors,1); 
    
    assert(cfg.iProjAngles > 1, 'Number of processed angles must be > 1')
    
    
    % fix if the number of projections is > 1024, or arrays are too large 
    Nproj_groups  = max(1,ceil(max([cfg.iProjAngles/1024, ...               % ASTRA constant memory limit 
        cfg.iProjU*cfg.iProjV*cfg.iProjAngles*4 / gpu.AvailableMemory, ...  % available memory limit
        cfg.iProjU * cfg.iProjV * cfg.iProjAngles / double(intmax('int32'))])));  % maximal array size allowed by CUDA limit 
    if length(split) > 3 
        % user-requested angular splitting can only increase the number of groups
        Nproj_groups = max(split(4), Nproj_groups);
    end

    if all(split == 1) && Nproj_groups == 1
        if numel(volData)*4 > 1024e6  % exceeded texture memory
            nsubVol = ceil(numel(volData)*4 / 1024e6); 
            nsubVol = 2^nextpow2(nsubVol);
%             split = ceil([sqrt(nsubVol),sqrt(nsubVol),1]);
            split = [1,1,nsubVol];
            if r.verbose>0; disp(['Volume array is larger than 1024MB, auto-splitting ', num2str(split)]); end
        else
             %% in the simple case call ASTRA_GPU_wrapper directly 
            volData = gpuArray(volData);
            % call ASTRA 
            projData = astra.ASTRA_GPU_wrapper('fp',volData, cfg, vectors,[],r.deformation_fields{:});
            clear volData
            if gpu.AvailableMemory < 4*numel(projData)
                projData = gather(projData); % prevent out of memory errors during next step
            end
            projData = astra2matlab(projData);
            if ~keep_on_GPU; projData = gather(projData); end
            return
        end
    end
    
    
    %% otherwise prepare data for split and call ASTRA_GPU_wrapper on subvolumes 

        
    if isscalar(split)
        split = split .* ones(ndims(volData),1);
    end
    assert(numel(volData)*4/prod(split) <= 1024e6, 'Volume array exceeded 1024MB, use more splitting')

    Nvol_full = size(volData);
    if numel(Nvol_full)<3
        Nvol_full(3)=1;
    end
    Nvol_sub = Nvol_full'./reshape(split(1:3),[],1);
    
    % Nvol_full (padded to 3 elements) is used for the message so that all six
    % format specifiers receive a value also for 2D volumes
    assert(all(mod(Nvol_sub,1)==0), 'Volume size %ix%ix%i is not dividable by split %ix%ix%i', Nvol_full, split(1:3))
    

    if ismatrix(volData)
        split(3) = 1;
        Nvol_sub(3) = 1;
    end

    % from here on cfg describes a single subvolume
    cfg.iVolX = cfg.iVolX/split(1);
    cfg.iVolY = cfg.iVolY/split(2);
    cfg.iVolZ = cfg.iVolZ/split(3);
    
    cfg_orig = cfg; 
    
    if Nproj_groups > 1
        % split the projections along the angles 
        % NOTE(review): the last group(s) can end up with 0 or 1 angles when
        % iProjAngles is not a multiple of the group size - confirm that the
        % ASTRA wrapper accepts such groups
        for i = 1:Nproj_groups
            ind = (1+(i-1)*ceil(cfg.iProjAngles/Nproj_groups)):i*ceil(cfg.iProjAngles/Nproj_groups);
            ind = ind(ind <= cfg.iProjAngles);  
            vectors_tmp{i} = vectors(ind,:);
            cfg_tmp{i} = cfg;
            cfg_tmp{i}.iProjAngles = length(ind);
        end
        cfg = cfg_tmp; vectors = vectors_tmp;
    else
        Nproj_groups = 1;
        cfg = {cfg};
        vectors = {vectors};
    end
    clear ind 
    

    if r.verbose > 1
        fprintf('Size of the full volume: %i %i %i\n', size(volData))
        fprintf('Size of the subvolume: %i %i %i\n', Nvol_sub)
        fprintf('Size of the one sinogram block: %i %i %i\n', cfg{1}.iProjU, cfg{1}.iProjV,  cfg{1}.iProjAngles)
    end
    
    %%!!!! note that in rare cases astra may fail if sinogram width is too small 
    
    assert(prod(Nvol_sub) * 4 <= 1024e6, 'Volume exceeded maximal size of texture 1024MB')
    
    % estimate required memory + (use only if 2x more memory is available)
    required_mem = 2*(prod(Nvol_sub)*(2+~isa(volData, 'gpuArray')) + cfg_orig.iProjU * cfg_orig.iProjV * cfg_orig.iProjAngles )*4;
    % keep projections on GPU only if there is enough memory
    keep_projections_on_GPU = Nproj_groups == 1 || gpu.AvailableMemory  > required_mem;
    
    if cfg{1}.iProjU * cfg{1}.iProjV * cfg{1}.iProjAngles > intmax('int32')
        error('Projection size exceeded maximum size allowed on GPU')
    end
    if (keep_on_GPU  || prod(split(1:3)) == 1) && numel(volData) < intmax('int32') % keep volume on GPU 
       volData = gpuArray(volData);  
    end

    iter = 1;
    for m = 1:Nproj_groups
        % split angularly (solve smaller groups of angles)
        % allocate memory for the projections 
        projData{m} = gpuArray.zeros(cfg{m}.iProjU, cfg{m}.iProjV, cfg{m}.iProjAngles, 'single');
        % split into volume cubes 
        for i = 1:split(1)
            for j =  1:split(2)
                for k = 1:split(3)
                    if r.verbose > 0
                        progressbar(iter, prod(split(1:3))*Nproj_groups+1, 20);
                    end
                    pos = [i,j,k];
                     for n = 1:3
                        % index range of the current subvolume along axis n
                        ind{n} = (1+(pos(n)-1)*Nvol_sub(n)):(pos(n)*Nvol_sub(n));
                        %% find optimal shift of the subvolume
                        if mod(split(n),2)==1  %% odd 
                            shift(n) = (pos(n) - ceil(split(n)/2))*Nvol_sub(n);
                        else
                            shift(n) = (pos(n) - ceil(split(n)/2)-1/2)*Nvol_sub(n);
                        end
                     end
                     
                     % extract subvolume to be processed 
                    if any(split(1:3)~=1) && isa(volData, 'gpuArray')
                        vol_small = volData(ind{:});  % take only small subvolume
                    elseif any(split(1:3)~=1)
                        vol_small = zeros(Nvol_sub','single'); 
                        % no output argument is taken - presumably the helper fills
                        % vol_small in place (TODO confirm against utils.get_from_3D_projection)
                        utils.get_from_3D_projection(vol_small,volData,[ind{1}(1),ind{2}(1)]-1,ind{3});
                    else
                        vol_small = volData;  % avoid data copying if possible 
                    end
                    vol_small = gpuArray(vol_small); 
                    
                    % split deformation field for nonrigid tomography 
                    % NOTE(review): only the first 3 cells are processed even if 6 were provided
                    if ~isempty(r.deformation_fields)
                        for ii = 1:3
                            N_deform = size(r.deformation_fields{ii}) ./ reshape(split(1:3),[],1)';
                            for jj = 1:3
                                ind_def{jj} = linspace(1+(pos(jj)-1)*N_deform(jj), pos(jj)*N_deform(jj), size(r.deformation_fields{ii},jj));
                            end
                            [X,Y,Z]= meshgrid(ind_def{:});
                            deformation_fields_sub{ii} = interp3(r.deformation_fields{ii},X,Y,Z);
                        end
                    else
                        deformation_fields_sub = {};
                    end
                                            
                    % compensate the subvolume offset by shifting columns 4:6 of the geometry vectors
                    vec = vectors{m};
                    vec(:,4:6) = bsxfun(@minus, vec(:,4:6), shift);
                        
                    req_mem = 2*numel(vol_small)*4 ; 
                    if gpu.AvailableMemory < req_mem
                        % print GPU status and workspace content to help with debugging, then fail
                        !nvidia-smi
                        whos
                        error('Too low GPU memory, avail: %3.2gGB / req: %3.2gGB, GPU %i/%i, projection group %i/%i, keep_proj_on_GPU=%i', gpu.AvailableMemory/1e9,req_mem/1e9, gpu.Index, gpuDeviceCount, m , Nproj_groups, keep_projections_on_GPU)
                    end

                    try
                        % avoid memory allocation, write directly to projData{m} -> no output arguments are needed 
                        astra.ASTRA_GPU_wrapper('fp',vol_small, cfg{m}, vec,projData{m}, deformation_fields_sub{:});
                        vol_small = [];  % soft mem clean 
                    catch err
                        if strcmpi(err.identifier,'parallel:gpu:array:OOM')
                            warning('Out of memory on GPU %i, try reset GPU or split the array onto smaller blocks', gpu.Index)
                            gpuDevice
                            reset(gpuDevice)
                        end
                        
                        rethrow(err)
                    end

                    iter =  iter+1;
                    if r.verbose>0
                        progressbar(iter, prod(split(1:3))*Nproj_groups+1, 20);
                    end
                end
            end
        end
        if ~keep_projections_on_GPU
            projData{m} = gather(projData{m}); 
        end
    end
    clear  volData vol_small 
    
   
    % permute / concatenate 
    if gpu.AvailableMemory < 4*numel(projData{1})*max(2,Nproj_groups)
        projData = gather_all(projData); 
    end
    
    projData =  astra2matlab(projData);
    
    if gpu.AvailableMemory < 8*numel(projData{1})*Nproj_groups
        projData = gather_all(projData); 
    end
    % concatenate the projected data along the angular (3rd) axis 
    projData = merge_projections(projData); 
    
    if ~keep_on_GPU 
        projData = gather(projData);
    elseif numel(projData) < intmax('int32') && gpu.AvailableMemory > 4*numel(projData)
        % return to GPU if requested by 'keep_on_GPU' parameter, but only if
        % the merged array fits into the currently available GPU memory
        projData = gpuArray(projData);
    end
     
end

function x = gather_all(x)
    % GATHER_ALL  Apply gather to the first length(x) cells of x.
    %   Each processed cell is overwritten by the gathered copy of its content
    %   and the updated cell array is returned.
    nCells = length(x);
    idx = 1;
    while idx <= nCells
        x{idx} = gather(x{idx});
        idx = idx + 1;
    end
end


function projData = merge_projections(projData_blocks)
    % MERGE_PROJECTIONS  Concatenate projection blocks along the 3rd axis.
    %   projData_blocks - cell array of 3D arrays with equal size along axes 1 and 2
    %   projData        - single array holding all blocks stacked along axis 3;
    %                     for a single block the block itself is returned
    %
    %   gpuArray blocks are merged by cat. Host blocks are copied one by one
    %   into a preallocated array via tomo.set_to_array. The allocation is
    %   retried up to 10 times, one second apart; if every attempt fails, the
    %   last allocation error is rethrown.
    Nblocks = length(projData_blocks); 
    if Nblocks > 1
        if isa(projData_blocks{1}, 'gpuArray')
            projData = cat(3, projData_blocks{:});
        else
            % faster and more memory efficient version 
            proj_size = [size(projData_blocks{1},1),size(projData_blocks{1},2),sum(cellfun(@(x)size(x,3), projData_blocks))]; 
            max_attempts = 10;
            allocated = false;
            for attempt = 1:max_attempts
                try
                    projData = zeros(proj_size, 'single'); 
                    allocated = true;
                    break
                catch err
                end
                if attempt < max_attempts
                    pause(1)  % give the system time to release memory before retrying
                end
            end
            % decide on the explicit success flag, not on the loop counter: an
            % allocation that succeeds on the last attempt must not be reported as failure
            if ~allocated
                warning('Unsufficient memory to allocate %3.2gGB RAM', prod(proj_size)*4/1e9)
                utils.check_available_memory
                rethrow(err)
            end
            offset = 0; 
            for ii = 1:Nblocks
                % no output argument is taken - presumably projData is written in place
                % starting at "offset" along the 3rd axis (TODO confirm against tomo.set_to_array)
                tomo.set_to_array(projData, projData_blocks{ii}, offset);
                offset = offset + size(projData_blocks{ii},3); 
            end
        end
    else
        projData = projData_blocks{1};
    end
end


