% ATX_PARTIAL  back-projector that allows splitting the full volume into smaller pieces
%     composed tomography back-projector based on the ASTRA toolbox 
%     can be used either for data in RAM or on GPU (automatically decided from the class of projData)
%     * volume is split based on the "split" parameter, 1 == no splitting
%     * Atx_partial tries to split data if GPU limits are exceeded (i.e. texture memory limits)
%
% vol_full = Atx_partial(projData, cfg, vectors,split,varargin)
%
% Inputs:
%   **projData       - single precision real array of projections, size [cfg.iProjV, cfg.iProjU, number of angles]
%   **cfg           - config structure generated by ASTRA_initialize
%   **vectors       - orientation of projections generated by ASTRA_initialize
%   **split         - scalar or 3 or 4 element vector, [split X, split Y, split Z, split angle ]
% *optional*
%   **deformation_fields      - 3x1 or 6x1 cell containing 3D arrays of local deformation of the object 
%   **GPU              - GPU id to be used for reconstruction
%   **verbose          - verbose = 0 : (default) quiet, verbose = 1: standard info , verbose = 2: debug 
%   **keep_on_GPU      - if true keep reconstructed volume on GPU to make it faster, default == false (the safe option)
%
% *returns*
%   ++vol_full         - back-projected volume of size [cfg.iVolX, cfg.iVolY, cfg.iVolZ]
%
% recompile commands
%  (Linux, GCC 4.8.5)   mexcuda -outdir private  ASTRA_GPU_wrapper/ASTRA_GPU_wrapper.cu ASTRA_GPU_wrapper/util3d.cu ASTRA_GPU_wrapper/par3d_fp.cu ASTRA_GPU_wrapper/par3d_bp.cu
%  (Windows)  mexcuda -outdir private  ASTRA_GPU_wrapper\ASTRA_GPU_wrapper.cu ASTRA_GPU_wrapper\util3d.cu ASTRA_GPU_wrapper\par3d_fp.cu ASTRA_GPU_wrapper\par3d_bp.cu

    
%*-----------------------------------------------------------------------*
%|                                                                       |
%|  Except where otherwise noted, this work is licensed under a          |
%|  Creative Commons Attribution-NonCommercial-ShareAlike 4.0            |
%|  International (CC BY-NC-SA 4.0) license.                             |
%|                                                                       |
%|  Copyright (c) 2017 by Paul Scherrer Institute (http://www.psi.ch)    |
%|                                                                       |
%|       Author: CXS group, PSI                                          |
%*-----------------------------------------------------------------------*
% You may use this code with the following provisions:
%
% If the code is fully or partially redistributed, or rewritten in another
%   computing language this notice should be included in the redistribution.
%
% If this code, or subfunctions or parts of it, is used for research in a 
%   publication or if it is fully or partially rewritten for another 
%   computing language the authors and institution should be acknowledged 
%   in written form in the publication: “Data processing was carried out 
%   using the “cSAXS software package” developed by the CXS group,
%   Paul Scherrer Institut, Switzerland.” 
%   Variations on the latter text can be incorporated upon discussion with 
%   the CXS group if needed to more specifically reflect the use of the package 
%   for the published work.
%
% A publication that focuses on describing features, or parameters, that
%    are already existing in the code should be first discussed with the
%    authors.
%   
% This code and subroutines are part of a continuous development, they 
%    are provided “as they are” without guarantees or liability on part
%    of PSI or the authors. It is the user responsibility to ensure its 
%    proper use and the correctness of the results.


function vol_full = Atx_partial(projData, cfg, vectors,split,varargin)


    import utils.*
    import math.*
    
    % parse optional inputs (unmatched parameters are silently ignored)
    par = inputParser;
    par.KeepUnmatched = true; 

    par.addOptional('deformation_fields', {})  % deformation_fields: 3x1 or 6x1 cell containing 3D arrays of local deformation of the object 
    par.addOptional('GPU', [])   % GPUs id to be used in reconstruction
    par.addOptional('keep_on_GPU', false)   % true - keep reconstructed volume on GPU to make it faster 
    par.addOptional('verbose', 0)   % 0 = quiet, >0 = progress bar, >1 = debug prints

    par.parse(varargin{:})
    r = par.Results;
    
    if isempty(r.deformation_fields); r.deformation_fields = {};  end 
    
    %% check the inputs + check memory availability on GPU 
    assert(gpuDeviceCount>0, 'No CUDA enabled GPU availible')

    if ~(  (isa(projData, 'gpuArray') && strcmp(classUnderlying(projData), 'single')) || ...
            isa(projData,  'single') ) || ~isreal(projData)
        error('Only single precision real input array supported')
    end
    
    % Select the requested GPU BEFORE anything is transferred to the GPU:
    % switching devices invalidates all existing gpuArrays, which would
    % otherwise include the deformation fields moved to the GPU below.
    gpu  = gpuDevice();
    if ~isempty(r.GPU) && gpu.Index ~= r.GPU(1)
        % switch and !! reset !! GPU 
        fields_on_gpu = iscell(r.deformation_fields) && ...
            any(cellfun(@(x) isa(x, 'gpuArray'), r.deformation_fields(:)));
        if  isa(projData, 'gpuArray') || fields_on_gpu
            error('Switching GPUs will reset content'); 
        end 
        gpu  = gpuDevice(r.GPU(1));
    end

    if ~isempty(r.deformation_fields) 
        assert(any(numel(r.deformation_fields) == [3,6]), 'Deformation field expected as 3x1 or 6x1 cell array')

        for i = 1:numel(r.deformation_fields)
            if ~( (isa(r.deformation_fields{i}, 'gpuArray') && strcmp(classUnderlying(r.deformation_fields{i}), 'single')) || ...
            isa(r.deformation_fields{i},  'single'))
                 error('Only single precision for deformation fields is supported')
            end
            r.deformation_fields{i} = gpuArray(r.deformation_fields{i}); % move on GPU, they are usually small 
        end
        r.deformation_fields = r.deformation_fields' ; % transpose so that array(:) results in sorted field 
    end


    projSize = size(projData); 
    cfg.iProjAngles = size(vectors,1);  % ignore value in config 

    
    assert(cfg.iVolX*cfg.iVolY*cfg.iVolZ > 0, 'Volume is empty'); 
    % size() always returns at least 2 elements, so test the data itself
    assert(~isempty(projData), 'Projections are empty'); 
        
    % input parameters check
    if ismatrix(projData); projSize = [projSize,1]; end   % single projection -> add singleton angle dimension
    assert(max(cfg.iProjU, cfg.iProjV) <= 4096, 'Sinogram exceed maximal size allowed by GPU (4096)')
    assert(all(projSize==[cfg.iProjV,cfg.iProjU,cfg.iProjAngles]), 'Wrong inputs size')
    assert(all(size(vectors)==[cfg.iProjAngles,12]), 'Wrong vectors size')

    split_projections = cfg.iProjAngles > 1024 || ...   % too many angles 
       numel(projData)*4 > 1024e6 || ...                % data array is too large, try splitting
       (length(split) > 3 && split(4) > 1);             % split contains also angular split 
        
    % be sure that ASTRA wrapper is fed by doubles !!
    for i = fieldnames(cfg)'
        cfg.(i{1}) = double(cfg.(i{1})); 
    end
    vectors = double(vectors); 
    split=double(split);
    
    if all(split == 1) &&  ~split_projections
        assert(cfg.iVolX*cfg.iVolY*cfg.iVolZ*4 < min(4*double(intmax('int32')),gpu.AvailableMemory), 'Volume array is too large for GPU, use Atx_sup_partial'); 
        %% in the simplest case call ASTRA_GPU_wrapper directly 
        if isa(projData, 'gpuArray'), r.keep_on_GPU = true; end  % assume that the data should stay on GPU if provided 
        projData = matlab2astra(gpuArray(projData));
        vol_full = astra.ASTRA_GPU_wrapper('bp',projData, cfg, vectors,[],r.deformation_fields{:});
        if ~r.keep_on_GPU; vol_full = gather(vol_full); end 
        return
    end
    
    %% otherwise prepare data for split and call ASTRA_GPU_wrapper on subvolumes 
    
    split = ceil(max(1,split));
    if isscalar(split)
        split = split .* ones(3,1);
    end
    % pad missing split factors (volume axes and/or angles) with 1 = no splitting
    split(end+1:4) = 1; 
    
    
    %% PREPARE VOLUME %% 

    Npix_full = [cfg.iVolX,cfg.iVolY,cfg.iVolZ];
    Npix_small = Npix_full(:)./reshape(split(1:3),[],1);
    
    cfg.iVolX = Npix_small(1);
    cfg.iVolY = Npix_small(2);
    cfg.iVolZ = Npix_small(3);
    
    assert( all(mod(Npix_small,1)==0), sprintf('Volume array cannot be divided to %i %i %i cubes', split(1:3)))
    assert(prod(Npix_small)*4 < min(4*double(intmax('int32')),gpu.AvailableMemory), 'Volume array is too large for GPU, use Atx_sup_partial'); 


    % one subvolume, plus the full volume if the volume is split
    req_vol_memory = 4*(cfg.iVolX*cfg.iVolY*cfg.iVolZ) * (1+any(split(1:3)>1)); 
    
    keep_volume_on_GPU =  isa(projData, 'gpuArray') ||...
                        (gpu.AvailableMemory * 0.5 > req_vol_memory) && ...
                        ( (cfg.iVolX*cfg.iVolY*cfg.iVolZ) < intmax('int32')  );
    
    % preallocate large array for results 
    if keep_volume_on_GPU || prod(split(1:3)) == 1
        % transfer to GPU now 
        vol_full = gpuArray.zeros(Npix_full, 'single');
    else
        % keep the volume in RAM 
        vol_full = zeros(Npix_full, 'single');
    end
    
    %% PREPARE PROJECTIONS %% 

    % fix if the number of projections is > 1024, or arrays are too large 
    Nproj_groups  = max(split(4),ceil(max([cfg.iProjAngles/1024, ...
        cfg.iProjU*cfg.iProjV*cfg.iProjAngles*4 / min(1024e6,gpu.AvailableMemory-1.024e9), ...
        cfg.iProjU * cfg.iProjV * cfg.iProjAngles / double(intmax('int32'))])));
    
    % there cannot be more groups than angles; otherwise floor() below is 0
    % and the number of groups becomes Inf
    Nproj_groups = min(Nproj_groups, cfg.iProjAngles); 
    % avoid some rounding issues during splitting 
    Nproj_groups = ceil(cfg.iProjAngles / floor(cfg.iProjAngles/Nproj_groups)); 


    if r.verbose > 1
        fprintf('Size of the full volume: %i %i %i\n', Npix_full)
        fprintf('Size of the subvolume: %i %i %i\n', Npix_small)
        fprintf('Size of the data: %i %i %i\n', size(projData))
        fprintf('Free GPU memory: %3.2g%%\n',  gpu.AvailableMemory/gpu.TotalMemory*100)
    end
        
    if  Nproj_groups > 1
        % fix if the number of projections is > 1024 (limitation of the ASTRA code )
        % also if the dataset is too large (>1024MB), do automatic splitting along angles
        group_size = ceil(cfg.iProjAngles/Nproj_groups); 
        proj_ind = cell(1, Nproj_groups); 
        vectors_tmp = cell(1, Nproj_groups); 
        cfg_tmp = cell(1, Nproj_groups); 
        for i = 1:Nproj_groups
            ind = (1+(i-1)*group_size):i*group_size;
            ind = ind(ind <= cfg.iProjAngles);  
            proj_ind{i} = ind;
            vectors_tmp{i} = vectors(ind,:);
            cfg_tmp{i} = cfg;
            cfg_tmp{i}.iProjAngles = length(ind);
        end
        cfg = cfg_tmp; vectors = vectors_tmp;
    else
        Nproj_groups = 1;
        cfg = {cfg};
        vectors = {vectors};
    end
    
    if ~isempty(r.deformation_fields) && any(split(1:3) > 1)
        error('Deformation field splitting not implemented')
    end
       

    % move the whole projection array to GPU at once if it fits together
    % with the per-group copy and the subvolume buffers
    if gpu.AvailableMemory > 4*(numel(projData)+ ...
                             cfg{1}.iProjU* cfg{1}.iProjV*cfg{1}.iProjAngles * (Nproj_groups>1) ...
                             +any(split>1)*prod(Npix_small)*(2 + ~isa(vol_full, 'gpuArray'))) && ...
                             numel(projData) < intmax('int32') 
         projData = gpuArray(projData); % move small blocks directly on GPU 
    end
    
    inParpool = ~isempty(getCurrentTask()); 
    iter = 1;
    shift = zeros(1,3); 
    for k = 1:Nproj_groups
        if Nproj_groups > 1
            if length(proj_ind{k}) > 1 && ~inParpool && ~isa(projData, 'gpuArray') && exist('+tomo/get_from_array.m', 'file')
                % use custom made MEX function from +tomo package, usually
                % faster but use a lot of CPU
                projData_small = tomo.get_from_array(projData, [], proj_ind{k}); % load subblock from projections 
            else
                projData_small = projData(:,:,proj_ind{k});  % split the projections if needed 
            end
        else
            projData_small = projData;
        end
        if gpu.AvailableMemory < 2*4*numel(projData_small)
            % in case of low GPU memory, transpose data in RAM 
            projData_small = gpuArray(matlab2astra(projData_small));  % transfer projections to GPU if not there yet 
        else
            % move subblocks of the data on GPU 
            projData_small = gpuArray(projData_small); 
            projData_small = matlab2astra(projData_small);  
        end   
        if numel(projData_small) * 4 > 1024e6
            error('Data exceeded maximal size of texture memory 1024MB, Increase "split" to reduce the projection size')
            
        end
        
        if gpu.AvailableMemory  < prod(Npix_small)*4 && prod(split(1:3)) ~= 1
            % memory needed to make CUDA array for texture memory 
            pause(0.1)
            !nvidia-smi 
            whos 
            error('Too low GPU memory, avail: %3.2gGB / req: %3.2gGB, GPU %i/%i', gpu.AvailableMemory/1e9,prod(Npix_small)*4/1e9, gpu.Index, gpuDeviceCount)
        end
        for z = 1:split(3)
            for x =  1:split(1)
                for y =1:split(2)
                    if r.verbose > 0
                        progressbar(iter, prod(split(1:3))*Nproj_groups+1, 20);
                    end
                    pos = [x,y,z];
                    for n = 1:3
                        % find optimal shift of the subvolume (subvolume
                        % centre relative to the full volume centre)
                        if mod(split(n),2)==1  %% odd 
                            shift(n) = (pos(n) - ceil(split(n)/2))*Npix_small(n);
                        else
                            shift(n) = (pos(n) - ceil(split(n)/2)-1/2)*Npix_small(n);
                        end
                    end
                     
                    % shift columns 4:6 of the projection vectors by the subvolume offset
                    vectors_tmp = vectors{k};
                    vectors_tmp(:,4:6) = bsxfun(@minus, vectors_tmp(:,4:6), shift);
                    if prod(split(1:3))== 1
                        req_mem = numel(projData_small)*4 ; 
                    else
                        req_mem = 2*numel(projData_small)*4 ; 
                    end
                    if gpu.AvailableMemory < req_mem
                        error('Too low GPU memory, avail: %3.2gGB / req: %3.2gGB, GPU %i/%i', gpu.AvailableMemory/1e9,req_mem/1e9, gpu.Index, gpuDeviceCount)
                    end
                    try
                        if prod(split(1:3))== 1
                            % no splitting, vol_small == vol_full -> write
                            % the backprojection directly to vol_full
                            % without copying -> no output arguments are needed 
                            astra.ASTRA_GPU_wrapper('bp', projData_small, cfg{k}, vectors_tmp,vol_full,r.deformation_fields{:});
                        else
                            vol_small = astra.ASTRA_GPU_wrapper('bp', projData_small, cfg{k}, vectors_tmp,[],r.deformation_fields{:});
                            % if some volume split is needed, add the subvolume
                            % to the full volume 
                            if ~keep_volume_on_GPU 
                                vol_small = gather(vol_small);  % return to RAM
                            end
                            if keep_volume_on_GPU || isa(vol_full, 'gpuArray') || ~exist('+tomo/add_to_3D_volume.m', 'file')
                                % return results to the large array 
                                % use matlab to do the copying on GPU 
                                vol_full =add_to_3D(vol_full, vol_small,([x,y,z]-1).*Npix_small');
                            else
                                % otherwise use parallelized CPU code
                                tomo.add_to_3D_volume(vol_full,vol_small, ([x,y,z]-1).*Npix_small', true); 
                            end
                        end
                    catch err
                        if strcmpi(err.identifier,'parallel:gpu:array:OOM')
                            warning('Out of memory on GPU %i, try reset GPU or split the array onto smaller blocks', gpu.Index)
                            gpuDevice
                        end
                        % any failure resets the GPU before the error is rethrown
                        reset(gpuDevice)
                        rethrow(err)
                    end

                    iter = iter+1;
                    if r.verbose > 0
                        progressbar(iter, prod(split(1:3))*Nproj_groups+1, 20);
                    end
                end
            end
        end
    end
    
    if ~r.keep_on_GPU
        vol_full = gather(vol_full); 
    end 

    
end
